diff --git a/Cargo.toml b/Cargo.toml index da3c457..d45f195 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,9 @@ description = "A library for defining domain specific languages in a polymorphic repository = "https://github.com/mlb2251/lambdas" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +python = ["dep:pyo3"] +default = [] [dependencies] serde_json = {version = "1.0", features = ["preserve_order"]} @@ -15,6 +18,7 @@ serde = {version = "1.0", features = ["derive"]} once_cell = "1.16.0" string_cache = "0.8.4" rustc-hash = "1.1.0" +pyo3 = { version = "0.25.1", features = ["auto-initialize"], optional=true} # [profile.release] # debug = true # for flamegraphs diff --git a/src/domains/mod.rs b/src/domains/mod.rs index 9f9c938..ce9f377 100644 --- a/src/domains/mod.rs +++ b/src/domains/mod.rs @@ -2,3 +2,9 @@ // just register each domain here by including it with `pub mod domain_name;` pub mod simple; pub mod prim_lists; + +#[cfg(feature = "python")] +pub mod simple_python; + +#[cfg(feature = "python")] +pub mod py; \ No newline at end of file diff --git a/src/domains/py.rs b/src/domains/py.rs new file mode 100644 index 0000000..25a733b --- /dev/null +++ b/src/domains/py.rs @@ -0,0 +1,120 @@ +use crate::*; +use std::collections::HashSet; +use std::sync::Arc; + + +use pyo3::prelude::*; +use pyo3::types::PyAny; + +#[derive(Clone,Debug, PartialEq, Eq, Hash)] +pub enum PyVal { + Int(i32), + List(Vec), +} + +#[derive(Clone,Debug, PartialEq, Eq, Hash)] +pub enum PyType { + TInt, + TList +} + + +type Val = crate::eval::Val; + + +#[cfg(feature = "python")] +pub fn create_python_production( + name: Symbol, + tp: SlowType, + lazy_args: Option<&[usize]>, + pyfunc: Py, +) -> Production { + let arity = tp.arity(); + let lazy: HashSet = lazy_args + .map(|xs| xs.iter().copied().collect()) + .unwrap_or_default(); + + use crate::eval::{CurriedFn, Val}; + + Production { + name: name.clone(), + val: Val::PrimFun(CurriedFn::::new(name.clone(), arity)), + tp, + arity, + lazy_args: lazy, + fn_ptr: Some(FnPtr::Python(Arc::new(pyfunc))) + } +} + +// From impls are needed for unwrapping values. We can assume the program +// has been type checked so it's okay to panic if the type is wrong. Each val variant +// must map to exactly one unwrapped type (though it doesnt need to be one to one in the +// other direction) +impl FromVal for i32 { + fn from_val(v: Val) -> Result { + match v { + Dom(PyVal::Int(i)) => Ok(i), + _ => Err("from_val_to_i32: not an int".into()) + } + } +} +impl> FromVal for Vec { + fn from_val(v: Val) -> Result { + match v { + Dom(PyVal::List(v)) => v.into_iter().map(|v| T::from_val(v)).collect(), + _ => Err("from_val_to_vec: not a list".into()) + } + } +} + +impl From for Val { + fn from(i: i32) -> Val { + Dom(PyVal::Int(i)) + } +} +impl> From> for Val { + fn from(vec: Vec) -> Val { + Dom(PyVal::List(vec.into_iter().map(|v| v.into()).collect())) + } +} + +impl Domain for PyVal { + type Data = (); + + #[cfg(feature = "python")] + fn py_val_to_py(py: Python<'_>, v: crate::eval::Val) -> PyResult> { + crate::domains::simple_python::val_to_py(py, v) + } + + #[cfg(feature = "python")] + fn py_py_to_val(obj: &Bound<'_, PyAny>) -> Result, String> { + crate::domains::simple_python::py_to_val(obj) + } + + fn new_dsl() -> DSL { + let prods = vec![]; + let dsl = DSL::new(prods); + dsl + } + + fn val_of_prim_fallback(p: &Symbol) -> Option { + None + } + + fn type_of_dom_val(&self) -> SlowType { + match self { + PyVal::Int(_) => SlowType::base(Symbol::from("int")), + PyVal::List(xs) => { + let elem_tp = if xs.is_empty() { + SlowType::Var(0) // (list t0) + } else { + let result = Self::type_of_dom_val(&xs.first().unwrap().clone().dom().unwrap()); + assert!(xs.iter().all(|v| result == Self::type_of_dom_val(&v.clone().dom().unwrap()))); + result + }; + SlowType::Term("list".into(),vec![elem_tp]) + }, + } + } + +} \ No newline at end of file diff --git a/src/domains/simple_python.rs b/src/domains/simple_python.rs new file mode 100644 index 0000000..32c6573 --- /dev/null +++ b/src/domains/simple_python.rs @@ -0,0 +1,43 @@ +#![cfg(feature = "python")] + +use pyo3::prelude::*; +use pyo3::types::PyList; +use pyo3::conversion::IntoPyObjectExt; + +use crate::domains::py::PyVal; +use crate::eval; +type Val = eval::Val; +use PyVal::*; + + +pub fn val_to_py(py: Python<'_>, v: Val) -> PyResult> { + match v.dom().expect("Val should be Dom") { + Int(i) => { + Ok(i.into_py_any(py)?) + } + List(xs) => { + let mut elems: Vec> = Vec::with_capacity(xs.len()); + for x in xs { + elems.push(val_to_py(py, x.clone())?); + } + + let list_bound: Bound<'_, PyList> = PyList::new(py, &elems)?; + + Ok(list_bound.into_any().unbind()) + } + } +} + +pub fn py_to_val(obj: &Bound<'_, PyAny>) -> Result { + if let Ok(i) = obj.extract::() { + return Ok(Val::from(Int(i))); + } + if let Ok(list) = obj.downcast::() { + let mut out: Vec = Vec::with_capacity(list.len()); + for item in list.iter() { + out.push(py_to_val(&item)?); + } + return Ok(Val::from(List(out))); + } + Err("unsupported Python type for PyVal".into()) +} diff --git a/src/dsl.rs b/src/dsl.rs index b78495f..27086af 100644 --- a/src/dsl.rs +++ b/src/dsl.rs @@ -3,10 +3,23 @@ use crate::*; use std::collections::{HashMap, HashSet}; use std::fmt::{Debug}; use std::hash::Hash; +#[cfg(feature = "python")] +use std::sync::Arc; +#[cfg(feature = "python")] +use pyo3::prelude::*; +#[cfg(feature = "python")] +use pyo3::types::{PyList, PyTuple}; pub type DSLFn = fn(Env, &Evaluator) -> VResult; +#[derive(Clone)] +pub enum FnPtr { + Native(DSLFn), + #[cfg(feature = "python")] + Python(Arc>), +} + #[derive(Clone)] pub struct Production { pub name: Symbol, // eg "map" or "0" or "[1,2,3]" @@ -14,7 +27,7 @@ pub struct Production { pub tp: SlowType, pub arity: usize, pub lazy_args: HashSet, - pub fn_ptr: Option>, + pub fn_ptr: Option>, } impl Debug for Production { @@ -23,6 +36,40 @@ impl Debug for Production { } } +impl Production { + #[inline] + pub fn call(&self, args: Env, handle: &Evaluator) -> VResult { + match &self.fn_ptr{ + Some(FnPtr::Native(f)) => f(args, handle), + + #[cfg(feature = "python")] + Some(FnPtr::Python(pyf)) => { + Python::with_gil(|py| -> Result, String> { + // Env -> Python list + let mut elems: Vec> = Vec::with_capacity(args.len()); + for v in &args.env { + let obj = D::py_val_to_py(py, v.clone()) + .map_err(|e| e.to_string())?; + elems.push(obj); + } + + let tuple_bound:Bound = PyTuple::new(py, &elems) + .map_err(|e| e.to_string())?; + + let ret = (**pyf) + .bind(py) + .call1(tuple_bound) + .map_err(|e| e.to_string())?; + + // Python -> Val + D::py_py_to_val(&ret) + }) + } + None => Err("primitive has no function pointer".into()), + } + } +} + #[derive(Clone, Debug)] pub struct DSL { @@ -65,7 +112,7 @@ impl Production { tp, arity, lazy_args, - fn_ptr: Some(fn_ptr), + fn_ptr: Some(FnPtr::Native(fn_ptr)), } } @@ -85,6 +132,13 @@ impl DSL { self.productions.insert(entry.name.clone(), entry); } + /// NEW: add a constant (arity 0) to the DSL + pub fn add_constant(&mut self, name: Symbol, tp: SlowType, val: Val) { + assert_eq!(tp.arity(), 0); + let prod = Production::val_raw(name, tp, val); + self.add_entry(prod); + } + /// given a primitive's symbol return a runtime Val object. For function primitives /// this should return a PrimFun(CurriedFn) object. pub fn val_of_prim(&self, p: &Symbol) -> Option> { @@ -110,5 +164,19 @@ pub trait Domain: Clone + Debug + PartialEq + Eq + Hash + Send + Sync { fn type_of_dom_val(&self) -> SlowType; fn new_dsl() -> DSL; + + #[cfg(feature = "python")] + fn py_val_to_py(py: Python<'_>, v: Val) -> PyResult> { + // default: not supported for this domain + Err(pyo3::exceptions::PyTypeError::new_err( + "Python bridge not implemented for this domain", + )) + } + + #[cfg(feature = "python")] + fn py_py_to_val(_obj: &Bound<'_, PyAny>) -> Result, String> { + Err("Python bridge not implemented for this domain".into()) + } + } diff --git a/src/eval.rs b/src/eval.rs index ea44495..26a55e9 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -131,7 +131,7 @@ impl CurriedFn { pub fn apply(mut self, arg: Val, handle: &Evaluator) -> VResult { self.partial_args.push_back(arg); if self.partial_args.len() == self.arity { - handle.dsl.productions.get(&self.name).unwrap().fn_ptr.unwrap() (self.partial_args, handle) + handle.dsl.productions.get(&self.name).unwrap().call(self.partial_args, handle) } else { Ok(Val::PrimFun(self)) }