Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions pymatbridge/matlab_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def set_matlab_var(self, name, value):
"""
Set up a variable in Matlab workspace
"""
run_dict = self.Matlab.run_func("pymat_set_variable.m",
{'name':name, 'value':value})
run_dict = self.Matlab.set_variable(name, value)

if run_dict['success'] == 'false':
raise MatlabInterperterError(line, run_dict['content']['stdout'])
Expand Down Expand Up @@ -194,11 +193,9 @@ def matlab(self, line, cell=None, local_ns=None):
val = local_ns[input]
except KeyError:
val = self.shell.user_ns[input]

# To make an array JSON serializable
if (isinstance(val, np.ndarray)):
val = val.tolist()

# The _Session.set_variable function which this calls
# should correctly detect numpy arrays and serialize them
# as json correctly.
self.set_matlab_var(input, val)

else:
Expand Down
45 changes: 41 additions & 4 deletions pymatbridge/pymatbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,31 @@
import subprocess
import sys
import json
from uuid import uuid4

# JSON encoder extension to handle complex numbers
class ComplexEncoder(json.JSONEncoder):
try:
from numpy import ndarray, generic
except ImportError:
class ndarray:
pass
generic = ndarray

try:
from scipy.sparse import spmatrix
except ImportError:
class spmatrix:
pass
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep - this all seems OK to me. Hard to imagine anyone would want to use this without numpy/scipy, but it's good not to require that.



# JSON encoder extension to handle complex numbers and numpy arrays
class PymatEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, complex):
return {'real':obj.real, 'imag':obj.imag}
if isinstance(obj, ndarray):
return obj.tolist()
if isinstance(obj, generic):
return obj.item()
# Handle the default case
return json.JSONEncoder.default(self, obj)

Expand Down Expand Up @@ -134,7 +153,7 @@ def start(self):
return False

def _response(self, **kwargs):
req = json.dumps(kwargs, cls=ComplexEncoder)
req = json.dumps(kwargs, cls=PymatEncoder)
self.socket.send_string(req)
resp = self.socket.recv_string()
return resp
Expand All @@ -154,7 +173,7 @@ def is_connected(self):
time.sleep(2)
return False

req = json.dumps(dict(cmd="connect"), cls=ComplexEncoder)
req = json.dumps(dict(cmd="connect"), cls=PymatEncoder)
self.socket.send_string(req)

start_time = time.time()
Expand Down Expand Up @@ -192,6 +211,24 @@ def run_code(self, code):
def get_variable(self, varname):
return self._json_response(cmd='get_var', varname=varname)['var']

def set_variable(self, varname, value):
if isinstance(value, spmatrix):
return self._set_sparse_variable(varname, value)
return self.run_func('pymat_set_variable.m',
{'name': varname, 'value': value})

def _set_sparse_variable(self, varname, value):
value = value.todok()
prefix = 'pymatbridge_temp_sparse_%s_' % uuid4().hex
self.set_variable(prefix + 'keys', value.keys())
# correct for 1-indexing in MATLAB
self.run_code('{0}keys = {0}keys + 1;'.format(prefix))
self.set_variable(prefix + 'values', value.values())
cmd = "{1} = sparse({0}keys(:, 1), {0}keys(:, 2), {0}values');"
result = self.run_code(cmd.format(prefix, varname))
self.run_code('clear {0}keys {0}values'.format(prefix))
return result


class Matlab(_Session):
def __init__(self, executable='matlab', socket_addr=None,
Expand Down