diff --git a/pymatbridge/matlab_magic.py b/pymatbridge/matlab_magic.py index 5d1d703..7d94e89 100644 --- a/pymatbridge/matlab_magic.py +++ b/pymatbridge/matlab_magic.py @@ -130,8 +130,7 @@ def set_matlab_var(self, name, value): """ Set up a variable in Matlab workspace """ - run_dict = self.Matlab.run_func("pymat_set_variable.m", - {'name':name, 'value':value}) + run_dict = self.Matlab.set_variable(name, value) if run_dict['success'] == 'false': raise MatlabInterperterError(line, run_dict['content']['stdout']) @@ -194,11 +193,9 @@ def matlab(self, line, cell=None, local_ns=None): val = local_ns[input] except KeyError: val = self.shell.user_ns[input] - - # To make an array JSON serializable - if (isinstance(val, np.ndarray)): - val = val.tolist() - + # The _Session.set_variable function which this calls + # should correctly detect numpy arrays and serialize them + # as json correctly. self.set_matlab_var(input, val) else: diff --git a/pymatbridge/pymatbridge.py b/pymatbridge/pymatbridge.py index 8b52d56..5121441 100644 --- a/pymatbridge/pymatbridge.py +++ b/pymatbridge/pymatbridge.py @@ -14,12 +14,31 @@ import subprocess import sys import json +from uuid import uuid4 -# JSON encoder extension to handle complex numbers -class ComplexEncoder(json.JSONEncoder): +try: + from numpy import ndarray, generic +except ImportError: + class ndarray: + pass + generic = ndarray + +try: + from scipy.sparse import spmatrix +except ImportError: + class spmatrix: + pass + + +# JSON encoder extension to handle complex numbers and numpy arrays +class PymatEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, complex): return {'real':obj.real, 'imag':obj.imag} + if isinstance(obj, ndarray): + return obj.tolist() + if isinstance(obj, generic): + return obj.item() # Handle the default case return json.JSONEncoder.default(self, obj) @@ -134,7 +153,7 @@ def start(self): return False def _response(self, **kwargs): - req = json.dumps(kwargs, cls=ComplexEncoder) + req = json.dumps(kwargs, cls=PymatEncoder) self.socket.send_string(req) resp = self.socket.recv_string() return resp @@ -154,7 +173,7 @@ def is_connected(self): time.sleep(2) return False - req = json.dumps(dict(cmd="connect"), cls=ComplexEncoder) + req = json.dumps(dict(cmd="connect"), cls=PymatEncoder) self.socket.send_string(req) start_time = time.time() @@ -192,6 +211,24 @@ def run_code(self, code): def get_variable(self, varname): return self._json_response(cmd='get_var', varname=varname)['var'] + def set_variable(self, varname, value): + if isinstance(value, spmatrix): + return self._set_sparse_variable(varname, value) + return self.run_func('pymat_set_variable.m', + {'name': varname, 'value': value}) + + def _set_sparse_variable(self, varname, value): + value = value.todok() + prefix = 'pymatbridge_temp_sparse_%s_' % uuid4().hex + self.set_variable(prefix + 'keys', value.keys()) + # correct for 1-indexing in MATLAB + self.run_code('{0}keys = {0}keys + 1;'.format(prefix)) + self.set_variable(prefix + 'values', value.values()) + cmd = "{1} = sparse({0}keys(:, 1), {0}keys(:, 2), {0}values');" + result = self.run_code(cmd.format(prefix, varname)) + self.run_code('clear {0}keys {0}values'.format(prefix)) + return result + class Matlab(_Session): def __init__(self, executable='matlab', socket_addr=None,