Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from . import name
# use mx.sym as short for symbol
from . import symbol as sym
from .symbol.numpy import _symbol as np_symbol
from . import symbol
from . import symbol_doc
from . import io
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _set_np_ndarray_class(cls):
_np_ndarray_cls = cls


def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op, output_is_list):
"""ctypes implementation of imperative invoke wrapper"""
if out is not None:
original_output = out
Expand Down Expand Up @@ -102,7 +102,7 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
create_ndarray_fn = _np_ndarray_cls if is_np_op else _ndarray_cls
if original_output is not None:
return original_output
if num_output.value == 1:
if num_output.value == 1 and not output_is_list:
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
else:
Expand Down
7 changes: 6 additions & 1 deletion python/mxnet/_ctypes/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _set_np_symbol_class(cls):
_np_symbol_cls = cls


def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op):
def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op, output_is_list):
sym_handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateAtomicSymbol(
ctypes.c_void_p(handle),
Expand All @@ -138,6 +138,11 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op):
'Symbols either as positional or keyword arguments, not both')
create_symbol_fn = _np_symbol_cls if is_np_op else _symbol_cls
s = create_symbol_fn(sym_handle)
if is_np_op:
if output_is_list:
s._output_is_list = True
else:
s._output_is_list = False
if args:
s._compose(*args, name=name)
elif kwargs:
Expand Down
10 changes: 10 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,12 +755,22 @@ def write_all_str(module_file, module_all_list):

_NP_INTERNAL_OP_PREFIX = '_npi_'

_NP_OUTPUT_IS_LIST_OPERATORS = ['npi_split']


def _is_np_op(op_name):
return op_name.startswith(_NP_OP_PREFIX) or op_name.startswith(_NP_EXT_OP_PREFIX)\
or op_name.startswith(_NP_INTERNAL_OP_PREFIX)


def _output_is_list(op_name):
if _is_np_op(op_name):
for target_operator_name in _NP_OUTPUT_IS_LIST_OPERATORS:
if target_operator_name in op_name:
return True
return False


def _get_op_submodule_name(op_name, op_name_prefix, submodule_name_list):
"""Get the submodule name of a specific op"""
assert op_name.startswith(op_name_prefix)
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/cython/ndarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ cdef class CachedOp:
return [NewArray(p_output_vars[i], p_output_stypes[i], self.is_np_sym) for i in range(num_output)]


def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0):
def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0, output_is_list=0):
"""cython implementation of imperative invoke wrapper"""
cdef unsigned long long ihandle = handle
cdef OpHandle chandle = <OpHandle>ihandle
Expand Down Expand Up @@ -221,7 +221,7 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0):

if original_output is not None:
return original_output
if num_output == 1:
if num_output == 1 and not output_is_list:
return NewArray(p_output_vars[0], p_output_stypes[0], is_np_op)
else:
return [NewArray(p_output_vars[i], p_output_stypes[i], is_np_op) for i in range(num_output)]
11 changes: 8 additions & 3 deletions python/mxnet/cython/symbol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,20 @@ def _set_np_symbol_class(cls):
_np_symbol_cls = cls


cdef NewSymbol(SymbolHandle handle, int is_np_sym=0):
cdef NewSymbol(SymbolHandle handle, int is_np_sym=0, int output_is_list=0):
"""Create a new symbol given handle"""
create_symbol_fn = _np_symbol_cls if is_np_sym else _symbol_cls
sym = create_symbol_fn(None)
if is_np_sym:
if output_is_list:
sym._output_is_list = True
else:
sym._output_is_list = False
(<SymbolBase>sym).chandle = handle
return sym


def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0):
def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0, output_is_list=0):
cdef unsigned long long ihandle = handle
cdef OpHandle chandle = <OpHandle>ihandle
cdef vector[string] ckeys
Expand Down Expand Up @@ -151,4 +156,4 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0):
&csym_keys[0] if csym_keys.size() != 0 else NULL,
&sym_args[0] if sym_args.size() != 0 else NULL))

return NewSymbol(ret_handle, is_np_op)
return NewSymbol(ret_handle, is_np_op, output_is_list)
9 changes: 6 additions & 3 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from collections import OrderedDict

from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer
from .. import symbol, ndarray, initializer, np_symbol
from ..symbol import Symbol
from ..ndarray import NDArray
from .. import name as _name
Expand Down Expand Up @@ -1055,15 +1055,18 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
... 'net1-symbol.json', ['data'], 'net1-0001.params')
>>> out2 = net2(x)
"""
sym = symbol.load(symbol_file)
if is_np_array():
sym = np_symbol.load(symbol_file)
else:
sym = symbol.load(symbol_file)
if isinstance(input_names, str):
input_names = [input_names]
if param_file is None:
# Get a valid type inference by using fp32
inputs = [symbol.var(i, dtype=mx_real_t) for i in input_names]
else:
# Do not specify type, rely on saved params type instead
inputs = [symbol.var(i) for i in input_names]
inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names]
ret = SymbolBlock(sym, inputs)
if param_file is not None:
ret.collect_params().load(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved')
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def split(ary, indices_or_sections, axis=0):
raise ValueError('indices_or_sections must either int or tuple of ints')
ret = _npi.split(ary, indices, axis, False)
if not isinstance(ret, list):
return [ret]
raise NotImplementedError('Output of split should be list, get a return type %s'%(str(type(ret))))
return ret


Expand Down
7 changes: 4 additions & 3 deletions python/mxnet/ndarray/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import
from ..ndarray_doc import _build_doc

from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op # pylint: disable=unused-import
from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op, _output_is_list # pylint: disable=unused-import
from ..util import use_np_shape # pylint: disable=unused-import


Expand Down Expand Up @@ -176,6 +176,7 @@ def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=F

code = []
is_np_op = _is_np_op(op_name)
output_is_list = _output_is_list(op_name)
doc_str_idx = 1
if is_np_op:
doc_str_idx = 2
Expand Down Expand Up @@ -241,8 +242,8 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
{verify_fn}("{op_name}", "{func_name}", ndargs, out)
""".format(verify_fn=verify_ndarrays_fn, op_name=op_name, func_name=func_name))
code.append("""
return _imperative_invoke(%d, ndargs, keys, vals, out, %s)"""%(
handle.value, str(is_np_op)))
return _imperative_invoke(%d, ndargs, keys, vals, out, %s, %s)"""%(
handle.value, str(is_np_op), str(output_is_list)))
else:
code.append("""
return (0,)""")
Expand Down
Loading