From 540c5670b74b50ec22db472dd91784a962c59d1f Mon Sep 17 00:00:00 2001 From: Mike Mao Date: Mon, 12 Aug 2019 06:58:48 +0000 Subject: [PATCH 1/3] add symbolic indexing, waiting for split and concatenate to be merged --- python/mxnet/__init__.py | 1 + python/mxnet/_ctypes/ndarray.py | 4 +- python/mxnet/_ctypes/symbol.py | 7 +- python/mxnet/base.py | 10 + python/mxnet/cython/ndarray.pyx | 4 +- python/mxnet/cython/symbol.pyx | 11 +- python/mxnet/gluon/block.py | 9 +- python/mxnet/ndarray/register.py | 7 +- python/mxnet/symbol/numpy/_symbol.py | 241 ++++++++++++++++++++-- python/mxnet/symbol/register.py | 11 +- src/operator/numpy/np_matrix_op-inl.h | 32 ++- src/operator/numpy/np_matrix_op.cc | 180 +++++++++++++++- src/operator/numpy/np_matrix_op.cu | 3 + src/operator/tensor/matrix_op.cc | 1 + tests/python/unittest/test_numpy_gluon.py | 164 +++++++++++++++ tests/python/unittest/test_numpy_op.py | 55 +++++ 16 files changed, 703 insertions(+), 37 deletions(-) create mode 100644 tests/python/unittest/test_numpy_gluon.py diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index e9c1229d7f2f..d6e63af02a5b 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -38,6 +38,7 @@ from . import name # use mx.sym as short for symbol from . import symbol as sym +from .symbol.numpy import _symbol as np_symbol from . import symbol from . import symbol_doc from . import io diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index b1a38c1d2621..7807bc42cd2c 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -70,7 +70,7 @@ def _set_np_ndarray_class(cls): _np_ndarray_cls = cls -def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op): +def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op, output_is_list): """ctypes implementation of imperative invoke wrapper""" if out is not None: original_output = out @@ -102,7 +102,7 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op): create_ndarray_fn = _np_ndarray_cls if is_np_op else _ndarray_cls if original_output is not None: return original_output - if num_output.value == 1: + if num_output.value == 1 and not output_is_list: return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle), stype=out_stypes[0]) else: diff --git a/python/mxnet/_ctypes/symbol.py b/python/mxnet/_ctypes/symbol.py index 01ba18b38963..e1618f42b03a 100644 --- a/python/mxnet/_ctypes/symbol.py +++ b/python/mxnet/_ctypes/symbol.py @@ -123,7 +123,7 @@ def _set_np_symbol_class(cls): _np_symbol_cls = cls -def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op): +def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op, output_is_list): sym_handle = SymbolHandle() check_call(_LIB.MXSymbolCreateAtomicSymbol( ctypes.c_void_p(handle), @@ -138,6 +138,11 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op): 'Symbols either as positional or keyword arguments, not both') create_symbol_fn = _np_symbol_cls if is_np_op else _symbol_cls s = create_symbol_fn(sym_handle) + if is_np_op: + if output_is_list: + s._output_is_list = True + else: + s._output_is_list = False if args: s._compose(*args, name=name) elif kwargs: diff --git a/python/mxnet/base.py b/python/mxnet/base.py index dd5fcf0f6db9..d681d23e0195 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -755,12 +755,22 @@ def write_all_str(module_file, module_all_list): _NP_INTERNAL_OP_PREFIX = '_npi_' +_NP_OUTPUT_IS_LIST_OPERATORS = ['npi_split'] + def _is_np_op(op_name): 
return op_name.startswith(_NP_OP_PREFIX) or op_name.startswith(_NP_EXT_OP_PREFIX)\ or op_name.startswith(_NP_INTERNAL_OP_PREFIX) +def _output_is_list(op_name): + if _is_np_op(op_name): + for target_operator_name in _NP_OUTPUT_IS_LIST_OPERATORS: + if target_operator_name in op_name: + return True + return False + + def _get_op_submodule_name(op_name, op_name_prefix, submodule_name_list): """Get the submodule name of a specific op""" assert op_name.startswith(op_name_prefix) diff --git a/python/mxnet/cython/ndarray.pyx b/python/mxnet/cython/ndarray.pyx index 50791e9b9a86..6d663e678f1a 100644 --- a/python/mxnet/cython/ndarray.pyx +++ b/python/mxnet/cython/ndarray.pyx @@ -170,7 +170,7 @@ cdef class CachedOp: return [NewArray(p_output_vars[i], p_output_stypes[i], self.is_np_sym) for i in range(num_output)] -def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0): +def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0, output_is_list=0): """cython implementation of imperative invoke wrapper""" cdef unsigned long long ihandle = handle cdef OpHandle chandle = ihandle @@ -221,7 +221,7 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0): if original_output is not None: return original_output - if num_output == 1: + if num_output == 1 and not output_is_list: return NewArray(p_output_vars[0], p_output_stypes[0], is_np_op) else: return [NewArray(p_output_vars[i], p_output_stypes[i], is_np_op) for i in range(num_output)] diff --git a/python/mxnet/cython/symbol.pyx b/python/mxnet/cython/symbol.pyx index 86fe8ae6db4f..1dab47df7f4e 100644 --- a/python/mxnet/cython/symbol.pyx +++ b/python/mxnet/cython/symbol.pyx @@ -96,15 +96,20 @@ def _set_np_symbol_class(cls): _np_symbol_cls = cls -cdef NewSymbol(SymbolHandle handle, int is_np_sym=0): +cdef NewSymbol(SymbolHandle handle, int is_np_sym=0, int output_is_list=0): """Create a new symbol given handle""" create_symbol_fn = _np_symbol_cls if is_np_sym else _symbol_cls sym = create_symbol_fn(None) + if is_np_sym: + if output_is_list: + sym._output_is_list = True + else: + sym._output_is_list = False (sym).chandle = handle return sym -def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0): +def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0, output_is_list=0): cdef unsigned long long ihandle = handle cdef OpHandle chandle = ihandle cdef vector[string] ckeys @@ -151,4 +156,4 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0): &csym_keys[0] if csym_keys.size() != 0 else NULL, &sym_args[0] if sym_args.size() != 0 else NULL)) - return NewSymbol(ret_handle, is_np_op) + return NewSymbol(ret_handle, is_np_op, output_is_list) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 97e6e8b68453..68a2d09f3d40 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -27,7 +27,7 @@ from collections import OrderedDict from ..base import mx_real_t, MXNetError -from .. import symbol, ndarray, initializer +from .. import symbol, ndarray, initializer, np_symbol from ..symbol import Symbol from ..ndarray import NDArray from .. import name as _name @@ -1055,7 +1055,10 @@ def imports(symbol_file, input_names, param_file=None, ctx=None): ... 
'net1-symbol.json', ['data'], 'net1-0001.params') >>> out2 = net2(x) """ - sym = symbol.load(symbol_file) + if is_np_array(): + sym = np_symbol.load(symbol_file) + else: + sym = symbol.load(symbol_file) if isinstance(input_names, str): input_names = [input_names] if param_file is None: @@ -1063,7 +1066,7 @@ def imports(symbol_file, input_names, param_file=None, ctx=None): inputs = [symbol.var(i, dtype=mx_real_t) for i in input_names] else: # Do not specify type, rely on saved params type instead - inputs = [symbol.var(i) for i in input_names] + inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names] ret = SymbolBlock(sym, inputs) if param_file is not None: ret.collect_params().load(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved') diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py index bdbfa1584ca6..c2f595aa1c83 100644 --- a/python/mxnet/ndarray/register.py +++ b/python/mxnet/ndarray/register.py @@ -24,7 +24,7 @@ from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import from ..ndarray_doc import _build_doc -from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op # pylint: disable=unused-import +from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op, _output_is_list # pylint: disable=unused-import from ..util import use_np_shape # pylint: disable=unused-import @@ -176,6 +176,7 @@ def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=F code = [] is_np_op = _is_np_op(op_name) + output_is_list = _output_is_list(op_name) doc_str_idx = 1 if is_np_op: doc_str_idx = 2 @@ -241,8 +242,8 @@ def %s(%s):"""%(func_name, ', '.join(signature))) {verify_fn}("{op_name}", "{func_name}", ndargs, out) """.format(verify_fn=verify_ndarrays_fn, op_name=op_name, func_name=func_name)) code.append(""" - return _imperative_invoke(%d, ndargs, keys, vals, out, %s)"""%( - handle.value, str(is_np_op))) + return _imperative_invoke(%d, ndargs, keys, vals, out, %s, %s)"""%( + handle.value, str(is_np_op), str(output_is_list))) else: code.append(""" return (0,)""") diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index a6699d60871a..e9385d9d1a77 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -20,14 +20,21 @@ from __future__ import absolute_import import ctypes +import json import numpy as _np from . import _op as _mx_np_op -from ...base import _LIB, SymbolHandle, numeric_types, mx_uint +from ...base import _LIB, SymbolHandle, numeric_types, mx_uint, integer_types, string_types +from ...base import c_str, c_handle_array +from ...base import py_str from ...util import check_call, set_module from ...context import current_context from ..symbol import Symbol from .._internal import _set_np_symbol_class from . 
import _internal as _npi +try: + from __builtin__ import slice as py_slice +except ImportError: + from builtins import slice as py_slice __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate'] @@ -39,25 +46,90 @@ def _num_outputs(sym): @set_module('mxnet.symbol.numpy') class _Symbol(Symbol): - def __getitem__(self, key): - num_outputs = _num_outputs(self) - if num_outputs == 1: - raise NotImplementedError - if not isinstance(key, int): + def __init__(self, handle): + super(_Symbol, self).__init__(handle) + self._output_is_list = False + + def __getitem__(self, key): # pylint: disable = too-many-return-statements, inconsistent-return-statements + num_outputs = len(self) + # print("Num of outputs is ", num_outputs) + if num_outputs == 1: # pylint: disable = too-many-nested-blocks + # If number of output is one and is not a list, perform ndarray basic slicing + if not self._output_is_list: + if isinstance(key, integer_types): + sliced = _npi.slice(self, key, key+1) + return _npi.reshape(sliced, (-3, -4)) + elif isinstance(key, py_slice): + if key.step is None or key.step != 0: + start = [None] if key.start is None else key.start + stop = [None] if key.stop is None else key.stop + return _npi.slice(self, start, stop, key.step) + else: + raise ValueError("slice step cannot be zero") + elif isinstance(key, list): + raise NotImplementedError + elif isinstance(key, tuple): + begin = [] + end = [] + step = [] + new_shape = () + for index in key: + if isinstance(index, py_slice): + if index.step is not None and index.step == 0: + raise ValueError("slice step cannot be zero") + begin.append(index.start) + end.append(index.stop) + step.append(index.step) + new_shape += (-2,) + elif isinstance(index, integer_types): + begin.append(index) + end.append(index+1) + step.append(1) + new_shape += (-3,) + new_shape += (-4,) + sliced = _npi.slice(self, begin, end, step) + return _npi.reshape(sliced, new_shape) + # perform trivial list slicing on length one list represented by flag + else: + if isinstance(key, integer_types): + if key in [-1, 0]: + self._output_is_list = False + return self + else: + raise IndexError + elif isinstance(key, py_slice): + if (key.start is None or key.start <= 0) and (key.stop is None or key.stop > 0): + return self + else: + raise ValueError + else: + raise IndexError + # list slicing on several nodes of outputs + elif num_outputs > 1: + if isinstance(key, py_slice): + start = 0 if key.start is None else key.start + stop = num_outputs if key.stop is None else key.stop + step = 1 if key.step is None else key.step + return Group([self[i] for i in range(start, stop, step)], _Symbol) + elif isinstance(key, integer_types): + if key >= num_outputs: + # Important, python determines the end by this exception + raise IndexError + handle = SymbolHandle() + check_call(_LIB.MXSymbolGetOutput( + self.handle, mx_uint(key), ctypes.byref(handle))) + return _Symbol(handle=handle) + else: + raise NotImplementedError + else: raise NotImplementedError - if key >= num_outputs: - # Important, python determines the end by this exception - raise IndexError - handle = SymbolHandle() - check_call(_LIB.MXSymbolGetOutput( - self.handle, mx_uint(key), ctypes.byref(handle))) - return _Symbol(handle=handle) + def __setitem__(self, key, value): raise NotImplementedError def __iter__(self): - raise AttributeError('_Symbol object has no attribute __iter__') + return (self[i] for i in range(len(self))) def 
__add__(self, other): """x.__add__(y) <=> x + y""" @@ -190,7 +262,9 @@ def __le__(self, other): raise TypeError("_Symbol does not support type {} as operand".format(str(type(other)))) def __len__(self): - raise NotImplementedError + output_count = mx_uint() + check_call(_LIB.MXSymbolGetNumOutputs(self.handle, ctypes.byref(output_count))) + return output_count.value def as_nd_ndarray(self): """Convert _Symbol to mxnet.symbol.Symbol to use its convenience fluent methods.""" @@ -858,6 +932,52 @@ def broadcast_to(self, *args, **kwargs): def broadcast_like(self, *args, **kwargs): raise AttributeError('_Symbol object has no attribute broadcast_like') + def save(self, fname, remove_amp_cast=True): + """Saves symbol to a file. + You can also use pickle to do the job if you only work on python. + The advantage of `load`/`save` functions is that the file contents are language agnostic. + This means the model saved by one language binding can be loaded by a different + language binding of `MXNet`. + You also get the benefit of being able to directly load/save from cloud storage(S3, HDFS). + Parameters + ---------- + fname : str + The name of the file. + - "s3://my-bucket/path/my-s3-symbol" + - "hdfs://my-bucket/path/my-hdfs-symbol" + - "/path-to/my-local-symbol" + remove_amp_cast : bool, optional + Whether to remove the amp_cast and amp_multicast operators, before saving the model. + See Also + -------- + symbol.load : Used to load symbol from file. + """ + if not isinstance(fname, string_types): + raise TypeError('fname need to be string') + + handle = self.handle + if remove_amp_cast: + handle = SymbolHandle() + check_call(_LIB.MXSymbolRemoveAmpCast(self.handle, ctypes.byref(handle))) + + processed_symbol = _Symbol(handle) + json_str = processed_symbol.save_json_string() + json_data = json.loads(json_str) + with open(fname, 'w') as file_out: + json.dump(json_data, file_out, indent=2, sort_keys=True) + + def save_json_string(self): + """Saves symbol to a JSON string. + See Also + -------- + symbol.load_json : Used to load symbol from JSON string. + """ + json_str = ctypes.c_char_p() + check_call(_LIB.MXSymbolSaveToJSON(self.handle, ctypes.byref(json_str))) + json_data = json.loads(py_str(json_str.value)) + json_data["output_is_list"] = self._output_is_list + return json.dumps(json_data) + @set_module('mxnet.symbol.numpy') def zeros(shape, dtype=_np.float32, order='C', ctx=None): @@ -1335,4 +1455,95 @@ def concatenate(seq, axis=0, out=None): return _npi.concatenate(*seq, dim=axis, out=out) +def Group(symbols, create_fn=_Symbol): + """Creates a symbol that contains a collection of other symbols, grouped together. + A classic symbol (`mx.sym.Symbol`) will be returned if all the symbols in the list + are of that type; a numpy symbol (`mx.sym.np._Symbol`) will be returned if all the + symbols in the list are of that type. A type error will be raised if a list of mixed + classic and numpy symbols are provided. + Example + ------- + >>> a = mx.sym.Variable('a') + >>> b = mx.sym.Variable('b') + >>> mx.sym.Group([a,b]) + + Parameters + ---------- + symbols : list + List of symbols to be grouped. + create_fn : mx.sym.Symbol or mx.sym.np._Symbol + Symbol class for creating the grouped symbol. + Returns + ------- + sym : Symbol + A group symbol. 
+ """ + if not symbols or any(not isinstance(sym, Symbol) for sym in symbols): + raise TypeError('Expected a list of symbols as input') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateGroup( + mx_uint(len(symbols)), + c_handle_array(symbols), ctypes.byref(handle))) + self = create_fn(handle) + self._output_is_list = True #pylint: disable = protected-access + return self + + +@set_module('mxnet.symbol.numpy') +def load_json_string(json_str): + """ + Loads symbol from json string. + Parameters + ---------- + json_str : str + A JSON string. + Returns + ------- + sym : Symbol + The loaded symbol. + See Also + -------- + Symbol.tojson : Used to save symbol into json string. + """ + if not isinstance(json_str, string_types): + raise TypeError('fname required to be string') + handle = SymbolHandle() + json_data = json.loads(json_str) + output_is_list = json_data["output_is_list"] + del json_data["output_is_list"] + check_call(_LIB.MXSymbolCreateFromJSON(c_str(json.dumps(json_data)), ctypes.byref(handle))) + s = _Symbol(handle) + s._output_is_list = output_is_list #pylint: disable = protected-access + return s + + +@set_module('mxnet.symbol.numpy') +def load(fname): + """Loads symbol from a JSON file. + You can also use pickle to do the job if you only work on python. + The advantage of load/save is the file is language agnostic. + This means the file saved using save can be loaded by other language binding of mxnet. + You also get the benefit being able to directly load/save from cloud storage(S3, HDFS). + Parameters + ---------- + fname : str + The name of the file, examples: + - `s3://my-bucket/path/my-s3-symbol` + - `hdfs://my-bucket/path/my-hdfs-symbol` + - `/path-to/my-local-symbol` + Returns + ------- + sym : Symbol + The loaded symbol. + See Also + -------- + Symbol.save : Used to save symbol into file. 
+ """ + if not isinstance(fname, string_types): + raise TypeError('fname need to be string') + with open(fname, 'r') as file_input: + json_data = json.load(file_input) + + return load_json_string(json.dumps(json_data)) + _set_np_symbol_class(_Symbol) diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py index a17dd79048d4..de5c44914302 100644 --- a/python/mxnet/symbol/register.py +++ b/python/mxnet/symbol/register.py @@ -27,7 +27,7 @@ from ..attribute import AttrScope from ..base import mx_uint, check_call, _LIB, py_str from ..symbol_doc import _build_doc -from ..base import _Null, _init_op_module, _is_np_op +from ..base import _Null, _init_op_module, _is_np_op, _output_is_list from ..name import NameManager # pylint: enable=unused-import @@ -144,6 +144,7 @@ def _generate_symbol_function_code(handle, op_name, func_name, signature_only=Fa signature = ndsignature + signature is_np_op = _is_np_op(op_name) + output_is_list = _output_is_list(op_name) verify_symbol_fn = _verify_np_symbol.__name__ if is_np_op else _verify_legacy_symbol.__name__ code = [] if arr_name: @@ -191,8 +192,8 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) key_var_num_args, key_var_num_args)) code.append(""" - return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name, %s)"""%( - handle.value, str(is_np_op))) + return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name, %s, %s)"""%( + handle.value, str(is_np_op), str(output_is_list))) else: code.append(""" def %s(%s):"""%(func_name, ', '.join(signature))) @@ -244,8 +245,8 @@ def %s(%s):"""%(func_name, ', '.join(signature))) if not hasattr(NameManager._current, "value"): NameManager._current.value = NameManager() name = NameManager._current.value.get(name, '%s') - return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name, %s)"""%( - func_name.lower(), handle.value, str(is_np_op))) + return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name, %s, %s)"""%( + func_name.lower(), handle.value, str(is_np_op), str(output_is_list))) if signature_only: code.append(""" diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 6d3d9ea5ec85..415a491c3f5d 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -26,8 +26,10 @@ #define MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_ #include -#include "../tensor/matrix_op-inl.h" +#include #include "../nn/concat-inl.h" +#include "../tensor/matrix_op-inl.h" +#include "np_broadcast_reduce_op.h" namespace mxnet { namespace op { @@ -60,6 +62,34 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, } } +struct NumpyXReshapeParam : public dmlc::Parameter { + mxnet::Tuple newshape; + std::string order; + DMLC_DECLARE_PARAMETER(NumpyXReshapeParam) { + DMLC_DECLARE_FIELD(newshape) + .set_default(mxnet::Tuple()) + .describe("The new shape should be compatible with the original shape." + " If an integer, then the result will be a 1-D array of that length." + " One shape dimension can be -1. In this case, the value is inferred" + " from the length of the array and remaining dimensions." 
+ " -2 to -6 are used for data manipulation" + " -2 copy this dimension from the input to the output shape" + " -3 will skip current dimension if and only if the current dim size is one" + " -4 copy all remain of the input dimensions to the output shape" + " -5 use the product of two consecutive dimensions of the input" + " shape as the output" + " -6 split one dimension of the input into two dimensions passed" + " subsequent to -6 in the new shape"); + DMLC_DECLARE_FIELD(order) + .set_default("C") + .describe("Read the elements of a using this index order, and place the elements into" + " the reshaped array using this index order. 'C' means to read/write the elements" + " using C-like index order, with the last axis index changing fastest, back to the" + " first axis index changing slowest. Note that currently only C-like order is" + " supported"); + } +}; + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 73340981037d..a99af063be2c 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -24,6 +24,7 @@ */ #include "./np_matrix_op-inl.h" +#include "../nn/concat-inl.h" namespace mxnet { namespace op { @@ -135,17 +136,165 @@ bool NumpyReshapeInferShape(const mxnet::TShape& src, mxnet::TShape* dst) { } } +bool NumpyXReshapeInferShape(const mxnet::TShape& src, + const mxnet::Tuple& target, + mxnet::TShape* output) { + bool target_shape_is_known = true; + dim_t target_size = 1; + for (int i = 0; i < target.ndim(); ++i) { + if (target[i] < 0) { + target_shape_is_known = false; + target_size = -1; + break; + } else { + target_size *= target[i]; + } + } + if (shape_is_known(src) && target_shape_is_known) { + CHECK_EQ(src.Size(), target_size) << "Cannot reshape array of size " + << src.Size() << " into shape " << target; + *output = TShape(target.begin(), target.end()); + return true; + } else if (!shape_is_known(src) || target.ndim() == -1) { + return false; + } else { + int unknown_axis = -1; + dim_t known_dim_size_prod = 1; + std::vector output_shape_vector; + int src_inx = 0; + for (int i = 0; i < target.ndim(); ++i) { + dim_t proposed_dim = target[i]; + CHECK(proposed_dim >= -6) + << "Dimension size must be greater than -6, received " << proposed_dim; + if (proposed_dim == -1) { + // infer the known dimension + CHECK_LT(unknown_axis, 0) + << "One and only one dim can be inferred"; + unknown_axis = output_shape_vector.size(); + output_shape_vector.push_back(1); + src_inx++; + } else if (proposed_dim == -2) { + // copy the dimension from src to output + CHECK_LT(src_inx, src.ndim()) + << "Unmatching dimension of proposed new shape"; + known_dim_size_prod *= src[src_inx]; + output_shape_vector.push_back(src[src_inx++]); + } else if (proposed_dim == -3) { + // skip the source dimension if and only if it is one + CHECK_EQ(src[src_inx], 1) + <<"-3 index should only be used to skip dimision size 1"; + src_inx++; + } else if (proposed_dim == -4) { + // copy all remaining dims from source + while (src_inx < src.ndim()) { + known_dim_size_prod *= src[src_inx]; + const int dn = src[src_inx++]; + output_shape_vector.push_back(dn); + } + } else if (proposed_dim == -5) { + // merge two dims from source + CHECK_LT(src_inx, src.ndim()-1) + <<"Not enough dimensions left for the product"; + const int d1 = src[src_inx++]; + const int d2 = src[src_inx++]; + if (!mxnet::dim_size_is_known(d1) || !mxnet::dim_size_is_known(d2)) { + CHECK_LT(unknown_axis, 0) + << "One and only one dim can be inferred"; 
+ unknown_axis = output_shape_vector.size(); + output_shape_vector.push_back(-1); + } else { + known_dim_size_prod *= d1*d2; + output_shape_vector.push_back(d1 * d2); + } + } else if (proposed_dim == -6) { + // split the source dim s into two dims + // read the left dim and then the right dim (either can be -1) + CHECK_LT(i + 2, target.ndim()); + CHECK_LT(src_inx, src.ndim()); + const int d0 = src[src_inx++]; + dim_t d1 = target[++i]; + dim_t d2 = target[++i]; + CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1."; + if (d1 == -1 && d0 >= 0) d1 = d0 / d2; // d0 must be known to do this + if (d2 == -1 && d0 >= 0) d2 = d0 / d1; // d0 must be known to do this + CHECK(d1 * d2 == static_cast(d0) || static_cast(d0) == dim_t(-1)) + <<"Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0; + if (d1 == -1) { + CHECK_LT(unknown_axis, 0) + << "One and only one dim can be inferred"; + unknown_axis = output_shape_vector.size(); + } else if (d2 == -1) { + CHECK_LT(unknown_axis, 0) + << "One and only one dim can be inferred"; + unknown_axis = output_shape_vector.size() + 1; + } + known_dim_size_prod *= d0 == -1 ? 1 : d0; + output_shape_vector.push_back(d1); + output_shape_vector.push_back(d2); + } else { + // greater than 0, new shape + known_dim_size_prod *= proposed_dim; + output_shape_vector.push_back(proposed_dim); + src_inx++; + } + } + + if (unknown_axis > -1) { + // if the input in zero size tensor, the output must be of known shape of zero size + CHECK_NE(known_dim_size_prod, 0) << "Cannot reshape array of size " + << src.Size() << " into shape " << target; + CHECK(src.Size() % known_dim_size_prod == 0) + << "Cannot reshape array of size " << src.Size() << " into shape " << target; + output_shape_vector[unknown_axis] = src.Size() / known_dim_size_prod; + } + + *output = mxnet::TShape(output_shape_vector.begin(), output_shape_vector.end()); + CHECK_EQ((*output).Size(), src.Size()) + << "Target output shape of size " << (*output).Size() + << " does not match the input shape of size " << src.Size(); + return true; + } +} + +bool NumpyXReshapeShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]"; + CHECK_EQ(out_attrs->size(), 1U); + const NumpyXReshapeParam& param = nnvm::get(attrs.parsed); + // sanity check + bool has_unknown_dim_size = false; + for (int i = 0; i < param.newshape.ndim(); ++i) { + if (param.newshape[i] < 0) { + CHECK_GE(param.newshape[i], -6) + << "Dimension size must be greater than or equal to -6"; + if (param.newshape[i] == -1) { + CHECK(!has_unknown_dim_size) << "Can only specify one unknown dimension"; + has_unknown_dim_size = true; + } + } + } + + mxnet::TShape output_shape; + bool success = NumpyXReshapeInferShape(in_attrs->at(0), param.newshape, &output_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, output_shape); + if (!success) { + success = ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]); + } + return success; +} + bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_attrs, mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]"; CHECK_EQ(out_attrs->size(), 1U); const NumpyReshapeParam& param = nnvm::get(attrs.parsed); - // sanity check + // sanity check bool has_unknown_dim_size = false; for (int i = 0; i < param.newshape.ndim(); ++i) { if (param.newshape[i] < 0) { - CHECK_EQ(param.newshape[i], -1) << "The shape dimension size to inferred must be -1"; + CHECK_EQ(param.newshape[i], -1) << 
"The shape dimension size to inferred must be -1"; CHECK(!has_unknown_dim_size) << "Can only specify one unknown dimension"; has_unknown_dim_size = true; } @@ -184,6 +333,33 @@ NNVM_REGISTER_OP(_np_reshape) .add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.") .add_arguments(NumpyReshapeParam::__FIELDS__()); +DMLC_REGISTER_PARAMETER(NumpyXReshapeParam); + +NNVM_REGISTER_OP(_npx_reshape) +.add_alias("_npi_reshape") +.describe(R"code()code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyXReshapeShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_reshape"}) +.set_attr("FCompute", UnaryOp::IdentityCompute) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.set_attr("FInplaceIdentity", + [](const NodeAttrs& attrs){ + return std::vector{true}; + }) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.") +.add_arguments(NumpyXReshapeParam::__FIELDS__()); + bool NumpySqueezeShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs) { diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index f192560f4ac9..33a7f0495ab5 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -42,5 +42,8 @@ NNVM_REGISTER_OP(_npi_concatenate) NNVM_REGISTER_OP(_backward_np_concat) .set_attr("FCompute", ConcatGradCompute); +NNVM_REGISTER_OP(_npx_reshape) +.set_attr("FCompute", UnaryOp::IdentityCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index f02a38ac07c4..45aec870b3d9 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -507,6 +507,7 @@ Example:: [1., 3.]] )code" ADD_FILELINE) .add_alias("_npx_slice") +.add_alias("_npi_slice") .set_attr_parser(ParamParser) .set_attr("FInferShape", SliceOpShape) .set_attr("FInferType", ElemwiseType<1, 1>) diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py new file mode 100644 index 000000000000..5fabaeb26941 --- /dev/null +++ b/tests/python/unittest/test_numpy_gluon.py @@ -0,0 +1,164 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: skip-file +from __future__ import absolute_import +from __future__ import division + +import mxnet as mx +from mxnet import gluon, autograd, np, npx +from mxnet.test_utils import use_np, assert_almost_equal +from common import with_seed +import random + + +@with_seed() +@use_np +def test_symbolic_basic_slicing(): + def get_slice_index(shape): + index = [] + step_switch = random.randint(0,1) + step = None if step_switch == 0 else [] + for i in range(len(shape)): + if shape[i] == 0: + index.append(slice(0,1)) + continue + if random.randint(0, 5) > 4: + index.append(random.randint(0, shape[i]-1)) + continue + s = random.randint(0, shape[i]-1) + e = random.randint(s+1, shape[i]) + if step_switch == 1: + index.append(slice(s, e, 1)) + elif step_switch == -1: + if e == shape[i]: + e -= 1 + s -= 1 + if s == -1: + s = None + index.append(slice(e, s, -1)) + else: + index.append(slice(s, e)) + return tuple(index) + + shapes = [ + (4, 6, 8, 9), + (1, 1, 1, 6), + (10, 20, 30), + ] + for shape in shapes: + for i in range(10): + index = get_slice_index(shape) + # Test basic slicing on single symbol + class TestSlicingSingleSymbol(gluon.HybridBlock): + def __init__(self, **kwargs): + super(TestSlicingSingleSymbol, self).__init__(**kwargs) + + def hybrid_forward(self, F, x): + x = x[:] + x = x[index] + return x + + net = TestSlicingSingleSymbol() + x = mx.nd.random.normal(shape=shape).as_np_ndarray() + x.attach_grad() + with autograd.record(): + imperative_out = net(x) + imperative_out.backward() + imperative_grad = x.grad.asnumpy() + + y = x + y.attach_grad() + net2 = TestSlicingSingleSymbol() + net2.hybridize() + with autograd.record(): + symbolic_out = net2(y) + symbolic_out.backward() + symbolic_grad = y.grad.asnumpy() + assert_almost_equal(imperative_out.asnumpy(), symbolic_out.asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(imperative_grad, symbolic_grad, rtol=1e-3, atol=1e-5) + + # Test save and load + net2.export('gluon') + net2_imported = gluon.SymbolBlock.imports('gluon-symbol.json', 'data', 'gluon-0000.params') + assert_almost_equal(net2(x).asnumpy(), net2_imported(x).asnumpy()) + + #Test slicing on symbol with list of outputs + slice_on_first_dim = index[0] if isinstance(index[0], slice) else slice(index[0], index[0] + 1) + class TestSlicingListOutputs(gluon.HybridBlock): + def __init__(self, **kwargs): + super(TestSlicingListOutputs, self).__init__(**kwargs) + + def hybrid_forward(self, F, x): + x = F.np.split(x, shape[0]) + x = x[slice_on_first_dim] + x = F.np.concatenate(x) + return F.np.sum(x) + + net = TestSlicingListOutputs() + x = mx.nd.random.normal(shape=shape).as_np_ndarray() + x.attach_grad() + with autograd.record(): + imperative_out = net(x) + imperative_out.backward() + imperative_grad = x.grad.asnumpy() + + y = x + y.attach_grad() + net2 = TestSlicingListOutputs() + net2.hybridize() + with autograd.record(): + symbolic_out = net2(y) + symbolic_out.backward() + symbolic_grad = y.grad.asnumpy() + assert_almost_equal(imperative_out.asnumpy(), symbolic_out.asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(imperative_grad, symbolic_grad, rtol=1e-3, atol=1e-5) + + # Test slicing on length one list of symbol (flag enabled list) + class TestSlicingSingletonList(gluon.HybridBlock): + def __init__(self, **kwargs): + super(TestSlicingSingletonList, self).__init__(**kwargs) + + def hybrid_forward(self, F, x): + x = F.np.split(x, 1) + x = x[0] + x = x[index] + return F.np.sum(x) + + net = TestSlicingSingletonList() + x = 
mx.nd.random.normal(shape=shape).as_np_ndarray() + x.attach_grad() + with autograd.record(): + imperative_out = net(x) + imperative_out.backward() + imperative_grad = x.grad.asnumpy() + + y = x + y.attach_grad() + net2 = TestSlicingSingletonList() + net2.hybridize() + with autograd.record(): + symbolic_out = net2(y) + symbolic_out.backward() + symbolic_grad = y.grad.asnumpy() + assert_almost_equal(imperative_out.asnumpy(), symbolic_out.asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(imperative_grad, symbolic_grad, rtol=1e-3, atol=1e-5) + + +if __name__ == '__main__': + import nose + nose.runmodule() \ No newline at end of file diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 2291bcdb6d3d..70b216f9c606 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -880,6 +880,61 @@ def get_new_shape(shape, axis): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_npx_reshape(): + class TestNumpyXReshape(HybridBlock): + def __init__(self, newshape): + super(TestNumpyXReshape, self).__init__() + self._newshape = newshape + + def hybrid_forward(self, F, a, *args, **kwargs): + return F.npx.reshape(a, self._newshape) + + test_cases = [ + [(2, 3, 5, 5), (-2, -1), (2, 75)], + [(2, 3, 5, 5), (-2, -2, -1), (2, 3, 25)], + [(5, 3, 4, 5), (-2, -1, -2), (5, 15, 4)], + [(2, 3, 5, 4), (-1, -2, -2), (8, 3, 5)], + [(2, 3, 5, 5), (-2, -2, -2, -2), (2, 3, 5, 5)], + [(2, 1, 4, 5), (-2, -3, -2, -2), (2, 4, 5)], + [(1, 1, 4, 1), (-3, -3, -2, -2), (4, 1)], + [(1, 1, 1, 1), (-3, -3, -3, -3), ()], + [(2, 4, 5, 3), (-1, 2, 2, 1), (30, 2, 2, 1)], + [(2, 3, 5, 6), (-4,), (2, 3, 5, 6)], + [(2, 3, 5, 6), (6, 1, -4), (6, 1, 5, 6)], + [(2, 3, 5, 6), (-5, -5), (6, 30)], + [(2, 3, 5, 6), (-5, -1), (6, 30)], + [(64,), (-6, 16, 4), (16, 4)], + [(64,), (-6, 16, -1), (16, 4)], + [(64, 1, 2, 3), (-6, 16, -1, -4), (16, 4, 1, 2, 3)] + ] + for hybridize in [True, False]: + for shape, newshape, expected_ret_shape in test_cases: + # test gluon + test_reshape = TestNumpyXReshape(newshape=newshape) + if hybridize: + test_reshape.hybridize() + + a = mx.nd.random.uniform(shape=shape).as_np_ndarray() + a.attach_grad() + with mx.autograd.record(): + y = test_reshape(a) + + assert y.shape == expected_ret_shape + assert_almost_equal(y.asnumpy(), a.asnumpy().reshape(expected_ret_shape), rtol=1e-3, atol=1e-5) + + # test backward + mx.autograd.backward(y) + expected_grad = _np.ones(shape) + assert_almost_equal(a.grad.asnumpy(), expected_grad, rtol=1e-3, atol=1e-5) + + # test imperative + npx_out = npx.reshape(a, newshape) + expected_out = _np.reshape(a.asnumpy(), expected_ret_shape) + assert_almost_equal(npx_out.asnumpy(), expected_out, rtol=1e-3, atol=1e-5) + + if __name__ == '__main__': import nose nose.runmodule() From 2f8eaddcab15d1cfece062ab2e388ffeda7f6a01 Mon Sep 17 00:00:00 2001 From: Mike Mao Date: Thu, 15 Aug 2019 06:04:12 +0000 Subject: [PATCH 2/3] Change the error raised in _op.split --- python/mxnet/ndarray/numpy/_op.py | 2 +- python/mxnet/symbol/numpy/_symbol.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index d7f3fd1ace54..fdb22868e3b8 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -680,7 +680,7 @@ def split(ary, indices_or_sections, axis=0): raise ValueError('indices_or_sections must either int or tuple of ints') ret = _npi.split(ary, indices, axis, 
False)
     if not isinstance(ret, list):
-        return [ret]
+        raise NotImplementedError('Output of split should be a list, got return type %s' % str(type(ret)))
     return ret

diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index e9385d9d1a77..672c122d68c5 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -1452,6 +1452,8 @@ def concatenate(seq, axis=0, out=None):
     res : ndarray
         The concatenated array.
     """
+    if len(seq) > 1:
+        return _npi.concatenate(*[seq[i] for i in range(len(seq))], dim=axis, out=out)
     return _npi.concatenate(*seq, dim=axis, out=out)

From a9a104ba87be34feeb165440e5466072c20722f8 Mon Sep 17 00:00:00 2001
From: Mike Mao
Date: Fri, 16 Aug 2019 05:28:16 +0000
Subject: [PATCH 3/3] Fix sanity issues

---
 python/mxnet/symbol/numpy/_symbol.py |  2 --
 src/operator/numpy/np_matrix_op.cc   | 16 ++++++++--------
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 672c122d68c5..0a6bf301dd7c 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -52,7 +52,6 @@ def __init__(self, handle):
 
     def __getitem__(self, key):  # pylint: disable = too-many-return-statements, inconsistent-return-statements
         num_outputs = len(self)
-        # print("Num of outputs is ", num_outputs)
         if num_outputs == 1:  # pylint: disable = too-many-nested-blocks
             # If number of output is one and is not a list, perform ndarray basic slicing
             if not self._output_is_list:
@@ -124,7 +123,6 @@ def __getitem__(self, key):  # pylint: disable = too-many-return-statements, inco
         else:
             raise NotImplementedError
 
-
     def __setitem__(self, key, value):
         raise NotImplementedError
 
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index a99af063be2c..961e24bc0a9e 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -199,7 +199,7 @@ bool NumpyXReshapeInferShape(const mxnet::TShape& src,
       const int d2 = src[src_inx++];
       if (!mxnet::dim_size_is_known(d1) || !mxnet::dim_size_is_known(d2)) {
         CHECK_LT(unknown_axis, 0)
-            << "One and only one dim can be inferred";
+          << "One and only one dim can be inferred";
         unknown_axis = output_shape_vector.size();
         output_shape_vector.push_back(-1);
       } else {
@@ -221,11 +221,11 @@ bool NumpyXReshapeInferShape(const mxnet::TShape& src,
         <<"Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0;
       if (d1 == -1) {
         CHECK_LT(unknown_axis, 0)
-            << "One and only one dim can be inferred";
+          << "One and only one dim can be inferred";
         unknown_axis = output_shape_vector.size();
       } else if (d2 == -1) {
         CHECK_LT(unknown_axis, 0)
-            << "One and only one dim can be inferred";
+          << "One and only one dim can be inferred";
         unknown_axis = output_shape_vector.size() + 1;
       }
       known_dim_size_prod *= d0 == -1 ? 1 : d0;
@@ -242,7 +242,7 @@ bool NumpyXReshapeInferShape(const mxnet::TShape& src,
   if (unknown_axis > -1) {
     // if the input in zero size tensor, the output must be of known shape of zero size
     CHECK_NE(known_dim_size_prod, 0) << "Cannot reshape array of size "
-        << src.Size() << " into shape " << target;
+                                     << src.Size() << " into shape " << target;
     CHECK(src.Size() % known_dim_size_prod == 0)
       << "Cannot reshape array of size " << src.Size() << " into shape " << target;
     output_shape_vector[unknown_axis] = src.Size() / known_dim_size_prod;
@@ -252,7 +252,7 @@ bool NumpyXReshapeInferShape(const mxnet::TShape& src,
   CHECK_EQ((*output).Size(), src.Size())
     << "Target output shape of size " << (*output).Size()
     << " does not match the input shape of size " << src.Size();
-  return true;
+    return true;
   }
 }
 
@@ -267,7 +267,7 @@ bool NumpyXReshapeShape(const nnvm::NodeAttrs& attrs,
   for (int i = 0; i < param.newshape.ndim(); ++i) {
     if (param.newshape[i] < 0) {
       CHECK_GE(param.newshape[i], -6)
-          << "Dimension size must be greater than or equal to -6";
+        << "Dimension size must be greater than or equal to -6";
       if (param.newshape[i] == -1) {
         CHECK(!has_unknown_dim_size) << "Can only specify one unknown dimension";
         has_unknown_dim_size = true;
       }
@@ -290,11 +290,11 @@ bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
   CHECK_EQ(out_attrs->size(), 1U);
   const NumpyReshapeParam& param = nnvm::get<NumpyReshapeParam>(attrs.parsed);
-  // sanity check 
+  // sanity check
   bool has_unknown_dim_size = false;
   for (int i = 0; i < param.newshape.ndim(); ++i) {
     if (param.newshape[i] < 0) {
-      CHECK_EQ(param.newshape[i], -1) << "The shape dimension size to inferred must be -1"; 
+      CHECK_EQ(param.newshape[i], -1) << "The shape dimension size to inferred must be -1";
       CHECK(!has_unknown_dim_size) << "Can only specify one unknown dimension";
       has_unknown_dim_size = true;
     }
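
---

Notes on usage (sketches only, not part of the diffs above; they assume a build with all three patches applied, and API details may differ).

The new _Symbol.__getitem__ lowers basic indexing to _npi.slice followed by _npi.reshape, using the reshape codes to drop integer-indexed axes (-3), keep sliced ones (-2), and copy whatever remains (-4). A minimal sketch:

import mxnet as mx
from mxnet import gluon, npx

npx.set_np()  # route hybridized blocks through mxnet.symbol.numpy._Symbol

class Slicer(gluon.HybridBlock):
    def hybrid_forward(self, F, x):
        # integer index -> axis squeezed (-3); slice -> axis kept (-2);
        # trailing axes copied through (-4)
        return x[1, 0:2]

net = Slicer()
net.hybridize()
x = mx.np.arange(24).reshape((2, 3, 4))
print(net(x).shape)  # expected (2, 4)

# export/import round trip, mirroring the new gluon test;
# SymbolBlock.imports now loads through np_symbol.load under numpy semantics
net.export('gluon')
net2 = gluon.SymbolBlock.imports('gluon-symbol.json', 'data', 'gluon-0000.params')
print(net2(x).shape)  # (2, 4) again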
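Multi-output symbols now behave like sequences: __len__ wraps MXSymbolGetNumOutputs, __iter__ yields one _Symbol per output, and a slice key builds a Group with the output_is_list flag set. A sketch, assuming the symbolic numpy split exported by this module is reachable as below:

import mxnet as mx
from mxnet import np_symbol  # module alias added to mxnet/__init__.py by patch 1

data = mx.sym.var('data').as_np_ndarray()  # classic Symbol -> numpy _Symbol
s = np_symbol.split(data, 4, axis=0)       # symbol with 4 outputs, list semantics
print(len(s))                              # 4
head = s[0:2]                              # slice key -> Group([s[0], s[1]])
print(len(head))                           # 2
for out in s:                              # iteration walks the outputs
    print(out.name)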
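The output_is_list flag threaded through _imperative_invoke and _symbol_creator exists because _npi_split must keep list semantics even when it produces a single output; patch 2 additionally makes the ndarray-side split fail loudly if that contract is ever broken. The observable behavior, sketched:

import mxnet as mx
from mxnet import npx

npx.set_np()
x = mx.np.arange(8).reshape((2, 4))
parts = mx.np.split(x, 1, axis=0)  # a single section still comes back as a list
print(type(parts), len(parts))     # <class 'list'> 1
print(parts[0].shape)              # (2, 4); symbolically, only 0 and -1 index this singleton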
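Persistence: _Symbol.save and save_json_string append an "output_is_list" key to the exported JSON, and load/load_json_string pop it back out before calling MXSymbolCreateFromJSON, so the flag survives a round trip. A sketch (reading the protected attribute purely for illustration):

import mxnet as mx
from mxnet import np_symbol

a = mx.sym.var('a').as_np_ndarray()
b = mx.sym.var('b').as_np_ndarray()
g = np_symbol.Group([a, b])  # grouping sets _output_is_list = True

json_str = g.save_json_string()
g.save('net-symbol.json')    # same JSON, written to disk

restored = np_symbol.load_json_string(json_str)
print(len(restored), restored._output_is_list)            # 2 True
print(np_symbol.load('net-symbol.json')._output_is_list)  # True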
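The _npx_reshape codes compose exactly as the new unit tests exercise them: -1 infers a dim, -2 copies one, -3 skips a size-1 dim, -4 copies all remaining dims, -5 merges two consecutive dims, -6 splits one dim in two. For example:

import mxnet as mx
from mxnet import npx

npx.set_np()
a = mx.np.ones((2, 3, 5, 5))
print(npx.reshape(a, (-2, -1)).shape)          # (2, 75): keep dim 0, infer the rest
print(npx.reshape(a, (-2, -2, -1)).shape)      # (2, 3, 25)
b = mx.np.ones((2, 1, 4, 5))
print(npx.reshape(b, (-2, -3, -2, -2)).shape)  # (2, 4, 5): -3 drops the size-1 dim
c = mx.np.ones((2, 3, 5, 6))
print(npx.reshape(c, (-5, -5)).shape)          # (6, 30): merge (2, 3) and (5, 6)
d = mx.np.ones((64,))
print(npx.reshape(d, (-6, 16, -1)).shape)      # (16, 4): split 64 into 16 x inferred 4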
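Inside NumpyXReshapeInferShape, resolving the single allowed -1 reduces to one division: every explicitly known output dim is folded into known_dim_size_prod, and the unknown slot receives src.Size() / known_dim_size_prod after a divisibility check. A hypothetical pure-Python rendering, for intuition only (the real logic lives in the C++ above):

def infer_unknown(src_size, known_dims):
    # mirrors the known_dim_size_prod bookkeeping of the C++ shape pass
    prod = 1
    for d in known_dims:
        prod *= d
    assert prod != 0 and src_size % prod == 0, 'cannot reshape'
    return src_size // prod

# (2, 3, 5, 5) reshaped with (-2, -1): dim 0 (= 2) is known, so -1 becomes 150 // 2
print(infer_unknown(150, [2]))  # 75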