diff --git a/primer/src/Foreword.hs b/primer/src/Foreword.hs index 8a5596617..a35a7b2dd 100644 --- a/primer/src/Foreword.hs +++ b/primer/src/Foreword.hs @@ -1,6 +1,10 @@ module Foreword ( module Protolude, module Unsafe, + insertAt, + adjustAt, + findAndAdjust, + findAndAdjustA, ) where -- In general, we should defer to "Protolude"'s exports and avoid name @@ -43,3 +47,31 @@ import Protolude hiding ( -- We should remove all uses of `unsafeHead`. See: -- https://github.com/hackworthltd/primer/issues/147 import Protolude.Unsafe as Unsafe (unsafeHead) + +-- | Insert an element at some index, returning `Nothing` if it is out of bounds. +insertAt :: Int -> a -> [a] -> Maybe [a] +insertAt n y xs = + if length a == n + then Just $ a ++ [y] ++ b + else Nothing + where + (a, b) = splitAt n xs + +-- | Apply a function to the element at some index, returning `Nothing` if it is out of bounds. +adjustAt :: Int -> (a -> a) -> [a] -> Maybe [a] +adjustAt n f xs = case splitAt n xs of + (a, b : bs) -> Just $ a ++ [f b] ++ bs + _ -> Nothing + +-- | Adjust the first element of the list which satisfies the predicate. +-- Returns `Nothing` if there is no such element. +findAndAdjust :: (a -> Bool) -> (a -> a) -> [a] -> Maybe [a] +findAndAdjust p f = \case + [] -> Nothing + x : xs -> if p x then Just $ f x : xs else (x :) <$> findAndAdjust p f xs + +-- | Like `findAndAdjust`, but in an `Applicative`. +findAndAdjustA :: Applicative m => (a -> Bool) -> (a -> m a) -> [a] -> m (Maybe [a]) +findAndAdjustA p f = \case + [] -> pure Nothing + x : xs -> if p x then Just . (: xs) <$> f x else (x :) <<$>> findAndAdjustA p f xs diff --git a/primer/src/Primer/API.hs b/primer/src/Primer/API.hs index 566246943..9235924cf 100644 --- a/primer/src/Primer/API.hs +++ b/primer/src/Primer/API.hs @@ -329,7 +329,7 @@ instance ToJSON Def viewProg :: App.Prog -> Prog viewProg p = Prog - { types = typeDefName <$> moduleTypes (progModule p) + { types = typeDefName <$> Map.elems (moduleTypes $ progModule p) , defs = ( \d -> Def diff --git a/primer/src/Primer/Action.hs b/primer/src/Primer/Action.hs index d1a8bb5c1..2f53081fb 100644 --- a/primer/src/Primer/Action.hs +++ b/primer/src/Primer/Action.hs @@ -46,6 +46,8 @@ import Primer.Core ( LVarName, LocalName (LocalName, unLocalName), TmVarRef (..), + TyConName, + TyVarName, Type, Type' (..), TypeCache (..), @@ -57,6 +59,7 @@ import Primer.Core ( defName, getID, unsafeMkGlobalName, + unsafeMkLocalName, valConArgs, valConName, valConType, @@ -87,7 +90,7 @@ import Primer.Core.Transform (renameLocalVar, renameTyVar, renameTyVarExpr) import Primer.Core.Utils (forgetTypeIDs, generateTypeIDs) import Primer.JSON import Primer.Module (Module (moduleDefs, moduleTypes)) -import Primer.Name (Name, NameCounter, unName, unsafeMkName) +import Primer.Name (Name, NameCounter, unName) import Primer.Name.Fresh ( isFresh, isFreshTy, @@ -379,6 +382,18 @@ data ProgAction DeleteDef GVarName | -- | Add a new type definition AddTypeDef ASTTypeDef + | -- | Rename the type definition with the given name, and its type constructor + RenameType TyConName Text + | -- | Rename the value constructor with the given name, in the given type + RenameCon TyConName ValConName Text + | -- | Rename the type parameter with the given name, in the given type + RenameTypeParam TyConName TyVarName Text + | -- | Add a value constructor at the given position, in the given type + AddCon TyConName Int Text + | -- | Change the type of the field at the given index of the given constructor + SetConFieldType TyConName ValConName Int (Type' ()) + | -- | Add a new field, at the given index, to the given constructor + AddConField TyConName ValConName Int (Type' ()) | -- | Execute a sequence of actions on the body of the definition BodyAction [Action] | -- | Execute a sequence of actions on the type annotation of the definition @@ -427,7 +442,7 @@ applyActionsToTypeSig smartHoles imports mod def actions = runReaderT go ( buildTypingContext - (concatMap moduleTypes $ mod : imports) + (foldMap moduleTypes $ mod : imports) (foldMap moduleDefs $ mod : imports) smartHoles ) @@ -477,7 +492,7 @@ applyActionsToTypeSig smartHoles imports mod def actions = applyActionsToBody :: (MonadFresh ID m, MonadFresh NameCounter m) => SmartHoles -> - [TypeDef] -> + Map TyConName TypeDef -> Map GVarName Def -> ASTDef -> [Action] -> @@ -512,7 +527,7 @@ applyActionAndCheck ty action z = do -- This is currently only used for tests. -- We may need it in the future for a REPL, where we want to build standalone expressions. -- We take a list of the types that should be in scope for the test. -applyActionsToExpr :: (MonadFresh ID m, MonadFresh NameCounter m) => SmartHoles -> [TypeDef] -> Expr -> [Action] -> m (Either ActionError (Either ExprZ TypeZ)) +applyActionsToExpr :: (MonadFresh ID m, MonadFresh NameCounter m) => SmartHoles -> Map TyConName TypeDef -> Expr -> [Action] -> m (Either ActionError (Either ExprZ TypeZ)) applyActionsToExpr sh typeDefs expr actions = foldM (flip applyActionAndSynth) (focusLoc expr) actions -- apply all actions <&> locToEither @@ -793,7 +808,7 @@ constructLam mx ze = do -- If a name is provided, use that. Otherwise, generate a fresh one. x <- case mx of Nothing -> mkFreshName ze - Just x -> pure (LocalName $ unsafeMkName x) + Just x -> pure (unsafeMkLocalName x) unless (isFresh x (target ze)) $ throwError NameCapture result <- flip replace ze <$> lam x (pure (target ze)) moveExpr Child1 result @@ -804,7 +819,7 @@ constructLAM mx ze = do -- If a name is provided, use that. Otherwise, generate a fresh one. x <- case mx of Nothing -> mkFreshName ze - Just x -> pure (LocalName $ unsafeMkName x) + Just x -> pure (unsafeMkLocalName x) unless (isFresh x (target ze)) $ throwError NameCapture result <- flip replace ze <$> lAM x (pure (target ze)) moveExpr Child1 result @@ -859,7 +874,7 @@ constructLet mx ze = case target ze of -- If a name is provided, use that. Otherwise, generate a fresh one. x <- case mx of Nothing -> mkFreshName ze - Just x -> pure (LocalName $ unsafeMkName x) + Just x -> pure (unsafeMkLocalName x) flip replace ze <$> let_ x emptyHole emptyHole e -> throwError $ NeedEmptyHole (ConstructLet mx) e @@ -869,7 +884,7 @@ constructLetrec mx ze = case target ze of -- If a name is provided, use that. Otherwise, generate a fresh one. x <- case mx of Nothing -> mkFreshName ze - Just x -> pure (LocalName $ unsafeMkName x) + Just x -> pure (unsafeMkLocalName x) flip replace ze <$> letrec x emptyHole tEmptyHole emptyHole e -> throwError $ NeedEmptyHole (ConstructLetrec mx) e @@ -933,7 +948,7 @@ renameLam y ze = case target ze of Lam m x e | unName (unLocalName x) == y -> pure ze | otherwise -> do - let y' = LocalName $ unsafeMkName y + let y' = unsafeMkLocalName y case renameLocalVar x y' e of Just e' -> pure $ replace (Lam m y' e') ze Nothing -> @@ -947,7 +962,7 @@ renameLAM b ze = case target ze of LAM m a e | unName (unLocalName a) == b -> pure ze | otherwise -> do - let b' = LocalName $ unsafeMkName b + let b' = unsafeMkLocalName b case renameTyVarExpr a b' e of Just e' -> pure $ replace (LAM m b' e') ze Nothing -> @@ -961,13 +976,13 @@ renameLet y ze = case target ze of Let m x e1 e2 | unName (unLocalName x) == y -> pure ze | otherwise -> do - let y' = LocalName $ unsafeMkName y + let y' = unsafeMkLocalName y (e1', e2') <- doRename x y' e1 e2 pure $ replace (Let m y' e1' e2') ze Letrec m x e1 t1 e2 | unName (unLocalName x) == y -> pure ze | otherwise -> do - let y' = LocalName $ unsafeMkName y + let y' = unsafeMkLocalName y (e1', e2') <- doRename x y' e1 e2 pure $ replace (Letrec m y' e1' t1 e2') ze _ -> @@ -984,7 +999,7 @@ renameCaseBinding :: forall m. ActionM m => Text -> CaseBindZ -> m CaseBindZ renameCaseBinding y caseBind = updateCaseBind caseBind $ \bind bindings rhs -> do let failure :: Text -> m a failure = throwError . CustomFailure (RenameCaseBinding y) - let y' = LocalName $ unsafeMkName y + let y' = unsafeMkLocalName y -- Check that 'y' doesn't clash with any of the other branch bindings let otherBindings = delete bind bindings @@ -1027,15 +1042,14 @@ constructTCon c zt = case target zt of constructTVar :: ActionM m => Text -> TypeZ -> m TypeZ constructTVar x ast = case target ast of - TEmptyHole{} -> flip replace ast <$> tvar (LocalName $ unsafeMkName x) + TEmptyHole{} -> flip replace ast <$> tvar (unsafeMkLocalName x) _ -> throwError $ CustomFailure (ConstructTVar x) "can only construct tvar in hole" constructTForall :: ActionM m => Maybe Text -> TypeZ -> m TypeZ constructTForall mx zt = do - x <- - LocalName <$> case mx of - Nothing -> mkFreshNameTy zt - Just x -> pure (unsafeMkName x) + x <- case mx of + Nothing -> LocalName <$> mkFreshNameTy zt + Just x -> pure (unsafeMkLocalName x) unless (isFreshTy x $ target zt) $ throwError NameCapture flip replace zt <$> tforall x C.KType (pure (target zt)) @@ -1048,7 +1062,7 @@ renameForall b zt = case target zt of TForall m a k t | unName (unLocalName a) == b -> pure zt | otherwise -> do - let b' = LocalName $ unsafeMkName b + let b' = unsafeMkLocalName b case renameTyVar a b' t of Just t' -> pure $ replace (TForall m b' k t') zt Nothing -> diff --git a/primer/src/Primer/App.hs b/primer/src/Primer/App.hs index 4e4caad7f..7f63fcfee 100644 --- a/primer/src/Primer/App.hs +++ b/primer/src/Primer/App.hs @@ -4,6 +4,7 @@ {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} -- This module defines the high level application functions. @@ -58,12 +59,30 @@ import Data.Aeson ( ) import Data.Bitraversable (bimapM) import Data.Generics.Product (position) +import Data.Generics.Uniplate.Operations (descendM, transform, transformM) import Data.Generics.Uniplate.Zipper ( fromZipper, ) +import Data.List.Extra ((!?)) import qualified Data.Map.Strict as Map import qualified Data.Set as Set -import Optics (re, traverseOf, view, (%), (%~), (.~), (?~), (^.), _Left, _Right) +import Optics ( + Field1 (_1), + Field2 (_2), + ReversibleOptic (re), + over, + toListOf, + traverseOf, + traversed, + view, + (%), + (%~), + (.~), + (?~), + (^.), + _Left, + _Right, + ) import Primer.Action ( Action, ActionError (..), @@ -74,30 +93,39 @@ import Primer.Action ( import Primer.Core ( ASTDef (..), ASTTypeDef (..), + Bind' (Bind), + CaseBranch, + CaseBranch' (CaseBranch), Def (..), Expr, - Expr' (EmptyHole, Var), + Expr' (Case, Con, EmptyHole, Hole, Var), ExprMeta, GVarName, GlobalName (baseName), ID (..), Kind (..), - LocalName (unLocalName), + LocalName (LocalName, unLocalName), Meta (..), PrimDef (..), TmVarRef (GlobalVarRef, LocalVarRef), + TyConName, + TyVarName, Type, Type' (..), TypeDef (..), TypeMeta, ValCon (..), + ValConName, defAST, defName, defPrim, getID, primFunType, qualifyName, + typeDefAST, + typesInExpr, unsafeMkGlobalName, + unsafeMkLocalName, _exprMeta, _exprMetaLens, _exprTypeMeta, @@ -106,8 +134,9 @@ import Primer.Core ( _typeMetaLens, ) import Primer.Core.DSL (create, emptyHole, tEmptyHole) -import Primer.Core.Transform (renameVar) -import Primer.Core.Utils (_freeTmVars, _freeTyVars, _freeVarsTy) +import qualified Primer.Core.DSL as DSL +import Primer.Core.Transform (foldApp, renameVar, unfoldApp, unfoldTApp) +import Primer.Core.Utils (freeVars, _freeTmVars, _freeTyVars, _freeVarsTy) import Primer.Eval (EvalDetail, EvalError) import qualified Primer.Eval as Eval import Primer.EvalFull (Dir, EvalFullError (TimedOut), TerminationBound, evalFull) @@ -132,6 +161,7 @@ import Primer.Typecheck ( checkEverything, checkTypeDefs, mkTypeDefMap, + synth, ) import Primer.Zipper ( ExprZ, @@ -203,15 +233,15 @@ importModules ms = do -- in the imported module are distinct from those already existing in the -- App. p <- gets appProg - checkedImports' <- runExceptT $ checkEverything NoSmartHoles $ CheckEverything{trusted = progImports p, toCheck = ms} - checkedImports <- case checkedImports' of - Left err -> throwError $ ActionError $ ImportFailed () err - Right ci -> pure ci + checkedImports <- + liftError (ActionError . ImportFailed ()) $ + checkEverything NoSmartHoles $ + CheckEverything{trusted = progImports p, toCheck = ms} let p' = p & #progImports %~ (<> checkedImports) modify (\a -> a{appProg = p'}) -- | Get all type definitions from all modules (including imports) -allTypes :: Prog -> [TypeDef] +allTypes :: Prog -> Map TyConName TypeDef allTypes p = foldMap moduleTypes $ progModule p : progImports p -- | Get all definitions from all modules (including imports) @@ -232,7 +262,7 @@ addTypeDef :: ASTTypeDef -> Prog -> Prog addTypeDef t p = let mod = progModule p tydefs = moduleTypes mod - tydefs' = tydefs <> [TypeDefAST t] + tydefs' = tydefs <> mkTypeDefMap [TypeDefAST t] mod' = mod{moduleTypes = tydefs'} in p{progModule = mod'} @@ -285,6 +315,15 @@ data ProgError | DefNotFound GVarName | DefAlreadyExists GVarName | DefInUse GVarName + | TypeDefIsPrim TyConName + | TypeDefNotFound TyConName + | TypeDefAlreadyExists TyConName + | ConNotFound ValConName + | ConAlreadyExists ValConName + | ParamNotFound TyVarName + | ParamAlreadyExists TyVarName + | TyConParamClash Name + | ValConParamClash Name | ActionError ActionError | EvalError EvalError | -- | Currently copy/paste is only exposed in the frontend via select @@ -299,6 +338,7 @@ data ProgError -- (However, this is not entirely true currently, see -- https://github.com/hackworthltd/primer/issues/3) TypeDefError Text + | IndexOutOfRange Int deriving (Eq, Show, Generic) deriving (FromJSON, ToJSON) via VJSON ProgError @@ -422,7 +462,7 @@ handleEvalRequest req = do handleEvalFullRequest :: MonadEditApp m => EvalFullReq -> m EvalFullResp handleEvalFullRequest (EvalFullReq{evalFullReqExpr, evalFullCxtDir, evalFullMaxSteps}) = do prog <- gets appProg - result <- evalFull (mkTypeDefMap $ allTypes prog) (allDefs prog) evalFullMaxSteps evalFullCxtDir evalFullReqExpr + result <- evalFull (allTypes prog) (allDefs prog) evalFullMaxSteps evalFullCxtDir evalFullReqExpr pure $ case result of Left (TimedOut e) -> EvalFullRespTimedOut e Right nf -> EvalFullRespNormal nf @@ -442,9 +482,10 @@ applyProgAction prog mdefName = \case -- Run a full TC solely to ensure that no references to the removed id -- remain. This is rather inefficient and could be improved in the -- future. - runExceptT (checkEverything @TypeError NoSmartHoles CheckEverything{trusted = progImports prog, toCheck = [mod']}) >>= \case - Left _ -> throwError $ DefInUse d - Right _ -> pure () + void . liftError (const $ DefInUse d) $ + checkEverything @TypeError + NoSmartHoles + CheckEverything{trusted = progImports prog, toCheck = [mod']} pure (prog', Nothing) DeleteDef d -> throwError $ DefNotFound d RenameDef d nameStr -> case lookupASTDef d (moduleDefs $ progModule prog) of @@ -482,12 +523,8 @@ applyProgAction prog mdefName = \case let def = ASTDef name expr ty pure (addDef def prog{progSelection = Just $ Selection name Nothing}, Just name) AddTypeDef td -> do - runExceptT @TypeError - ( runReaderT - (checkTypeDefs [TypeDefAST td]) - (buildTypingContext (allTypes prog) mempty NoSmartHoles) - ) - >>= \case + (addTypeDef td prog, mdefName) + <$ liftError -- The frontend should never let this error case happen, -- so we just dump out a raw string for debugging/logging purposes -- (This is not currently true! We should synchronise the frontend with @@ -495,8 +532,220 @@ applyProgAction prog mdefName = \case -- data T (T : *) = T -- but the TC rejects it. -- see https://github.com/hackworthltd/primer/issues/3) - Left err -> throwError $ TypeDefError $ show err - Right _ -> pure (addTypeDef td prog, mdefName) + (TypeDefError . show @TypeError) + ( runReaderT + (checkTypeDefs $ mkTypeDefMap [TypeDefAST td]) + (buildTypingContext (allTypes prog) mempty NoSmartHoles) + ) + RenameType old (unsafeMkGlobalName -> new) -> + (,Nothing) <$> do + traverseOf + #progModule + ( traverseOf #moduleTypes (updateType <=< pure . updateRefsInTypes) + <=< pure . over (#moduleDefs % traversed % #_DefAST) (updateDefBody . updateDefType) + ) + prog + where + updateType m = do + d0 <- + -- NB We do not allow primitive types to be renamed. + -- To relax this, we'd have to be careful about how it interacts with type-checking of primitive literals. + maybe (throwError $ TypeDefIsPrim old) pure . typeDefAST + =<< maybe (throwError $ TypeDefNotFound old) pure (Map.lookup old m) + -- TODO we should really check this against _all_ modules, but we will very shortly be adding namespacing + when (Map.member new m) $ throwError $ TypeDefAlreadyExists new + let nameRaw = baseName new + when (nameRaw `elem` map (unLocalName . fst) (astTypeDefParameters d0)) $ throwError $ TyConParamClash nameRaw + pure $ Map.insert new (TypeDefAST $ d0 & #astTypeDefName .~ new) $ Map.delete old m + updateRefsInTypes = + over + (traversed % #_TypeDefAST % #astTypeDefConstructors % traversed % #valConArgs % traversed) + $ transform $ over (#_TCon % _2) updateName + updateDefType = + over + #astDefType + $ transform $ over (#_TCon % _2) updateName + updateDefBody = + over + #astDefExpr + $ transform $ over typesInExpr $ transform $ over (#_TCon % _2) updateName + updateName n = if n == old then new else n + RenameCon type_ old (unsafeMkGlobalName -> new) -> + (,Nothing) <$> do + when (new `elem` allConNames prog) $ throwError $ ConAlreadyExists new + traverseOf + #progModule + ( traverseOf #moduleTypes updateType + <=< traverseOf #moduleDefs (pure . updateDefs) + ) + prog + where + updateType = + alterTypeDef + ( traverseOf + #astTypeDefConstructors + ( maybe (throwError $ ConNotFound old) pure + . findAndAdjust ((== old) . valConName) (#valConName .~ new) + ) + ) + type_ + updateDefs = + over (traversed % #_DefAST % #astDefExpr) $ + transform $ over (#_Con % _2) updateName + updateName n = if n == old then new else n + RenameTypeParam type_ old (unsafeMkLocalName -> new) -> + (,Nothing) + <$> traverseOf + #progModule + (traverseOf #moduleTypes updateType) + prog + where + updateType = + alterTypeDef + (pure . updateConstructors <=< updateParam) + type_ + updateParam def = do + when (new `elem` map fst (astTypeDefParameters def)) $ throwError $ ParamAlreadyExists new + let nameRaw = unLocalName new + when (nameRaw == baseName (astTypeDefName def)) $ throwError $ TyConParamClash nameRaw + when (nameRaw `elem` map (baseName . valConName) (astTypeDefConstructors def)) $ throwError $ ValConParamClash nameRaw + def + & traverseOf + #astTypeDefParameters + ( maybe (throwError $ ParamNotFound old) pure + . findAndAdjust ((== old) . fst) (_1 .~ new) + ) + updateConstructors = + over + ( #astTypeDefConstructors + % traversed + % #valConArgs + % traversed + ) + $ over _freeVarsTy $ \(_, v) -> TVar () $ updateName v + updateName n = if n == old then new else n + AddCon type_ index (unsafeMkGlobalName -> con) -> + (,Nothing) + <$> do + when (con `elem` allConNames prog) $ throwError $ ConAlreadyExists con + traverseOf + #progModule + ( traverseOf + (#moduleDefs % traversed % #_DefAST % #astDefExpr) + updateDefs + <=< traverseOf + #moduleTypes + updateType + ) + prog + where + updateDefs = transformCaseBranches prog type_ $ \bs -> do + m' <- DSL.meta + maybe (throwError $ IndexOutOfRange index) pure $ insertAt index (CaseBranch con [] (EmptyHole m')) bs + updateType = + alterTypeDef + ( traverseOf + #astTypeDefConstructors + (maybe (throwError $ IndexOutOfRange index) pure . insertAt index (ValCon con [])) + ) + type_ + SetConFieldType type_ con index new -> + (,Nothing) + <$> traverseOf + #progModule + ( traverseOf #moduleDefs updateDefs + <=< traverseOf #moduleTypes updateType + ) + prog + where + updateType = + alterTypeDef + ( traverseOf #astTypeDefConstructors $ + maybe (throwError $ ConNotFound con) pure + <=< findAndAdjustA + ((== con) . valConName) + ( traverseOf + #valConArgs + (maybe (throwError $ IndexOutOfRange index) pure . adjustAt index (const new)) + ) + ) + type_ + updateDefs = traverseOf (traversed % #_DefAST % #astDefExpr) (updateDecons <=< updateCons) + updateCons e = case unfoldApp e of + (e'@(Con _ con'), args) | con' == con -> do + m' <- DSL.meta + case adjustAt index (Hole m') args of + Just args' -> foldApp e' =<< traverse (descendM updateCons) args' + Nothing -> do + -- The constructor is not applied as far as the changed field, + -- so the full application still typechecks, but its type has changed. + -- Thus, we put the whole thing in to a hole. + Hole <$> DSL.meta <*> (foldApp e' =<< traverse (descendM updateCons) args) + _ -> + -- NB we can't use `transformM` here because we'd end up seeing incomplete applications before full ones + descendM updateCons e + updateDecons = transformCaseBranches prog type_ $ + traverse $ \cb@(CaseBranch vc binds e) -> + if vc == con + then do + Bind _ v <- maybe (throwError $ IndexOutOfRange index) pure $ binds !? index + CaseBranch vc binds + <$> + -- TODO a custom traversal could be more efficient - reusing `_freeTmVars` means that we continue in + -- to parts of the tree where `v` is shadowed, and thus where the traversal will never have any effect + traverseOf + _freeTmVars + ( \(m, v') -> + if v' == v + then Hole <$> DSL.meta <*> pure (Var m $ LocalVarRef v') + else pure (Var m $ LocalVarRef v') + ) + e + else pure cb + AddConField type_ con index new -> + (,Nothing) + <$> traverseOf + #progModule + ( traverseOf #moduleDefs updateDefs + <=< traverseOf #moduleTypes updateType + ) + prog + where + updateType = + alterTypeDef + ( traverseOf #astTypeDefConstructors $ + maybe (throwError $ ConNotFound con) pure + <=< findAndAdjustA + ((== con) . valConName) + ( traverseOf + #valConArgs + (maybe (throwError $ IndexOutOfRange index) pure . insertAt index new) + ) + ) + type_ + updateDefs = traverseOf (traversed % #_DefAST % #astDefExpr) (updateDecons <=< updateCons) + updateCons e = case unfoldApp e of + (e'@(Con _ con'), args) | con' == con -> do + m' <- DSL.meta + case insertAt index (EmptyHole m') args of + Just args' -> foldApp e' =<< traverse (descendM updateCons) args' + Nothing -> + -- The constructor is not applied as far as the field immediately prior to the new one, + -- so the full application still typechecks, but its type has changed. + -- Thus, we put the whole thing in to a hole. + Hole <$> DSL.meta <*> (foldApp e' =<< traverse (descendM updateCons) args) + _ -> + -- NB we can't use `transformM` here because we'd end up seeing incomplete applications before full ones + descendM updateCons e + updateDecons = transformCaseBranches prog type_ $ + traverse $ \cb@(CaseBranch vc binds e) -> + if vc == con + then do + m' <- DSL.meta + newName <- LocalName <$> freshName (freeVars e) + binds' <- maybe (throwError $ IndexOutOfRange index) pure $ insertAt index (Bind m' newName) binds + pure $ CaseBranch vc binds' e + else pure cb BodyAction actions -> do withDef mdefName prog $ \def -> do smartHoles <- gets $ progSmartHoles . appProg @@ -684,7 +933,7 @@ newEmptyProg = { progImports = mempty , progModule = Module - { moduleTypes = [] + { moduleTypes = mempty , moduleDefs = Map.singleton (defName def) def } , progSelection = Nothing @@ -892,11 +1141,7 @@ tcWholeProg p = , selectedNode = updatedNode } pure $ p'{progSelection = newSel} - in do - x <- runExceptT $ runReaderT tc $ buildTypingContext (allTypes p) (allDefs p) (progSmartHoles p) - case x of - Left e -> throwError $ ActionError e - Right prog -> pure prog + in liftError ActionError $ runReaderT tc $ progCxt p copyPasteBody :: MonadEditApp m => Prog -> (GVarName, ID) -> GVarName -> [Action] -> m Prog copyPasteBody p (fromDefName, fromId) toDefName setup = do @@ -990,14 +1235,72 @@ copyPasteBody p (fromDefName, fromId) toDefName setup = do lookupASTDef :: GVarName -> Map GVarName Def -> Maybe ASTDef lookupASTDef name = defAST <=< Map.lookup name -defaultTypeDefs :: [TypeDef] +alterTypeDef :: + MonadEditApp m => + (ASTTypeDef -> m ASTTypeDef) -> + TyConName -> + Map TyConName TypeDef -> + m (Map TyConName TypeDef) +alterTypeDef f type_ = + Map.alterF + ( maybe + (throwError $ TypeDefNotFound type_) + ( maybe + (throwError $ TypeDefIsPrim type_) + (map (Just . TypeDefAST) . f) + . typeDefAST + ) + ) + type_ + +-- | Apply a bottom-up transformation to all branches of case expressions on the given type. +transformCaseBranches :: + MonadEditApp m => + Prog -> + TyConName -> + ([CaseBranch] -> m [CaseBranch]) -> + Expr -> + m Expr +transformCaseBranches prog type_ f = transformM $ \case + Case m scrut bs -> do + scrutType <- + fst + <$> runReaderT + (liftError (ActionError . TypeError) $ synth scrut) + (progCxt prog) + Case m scrut + <$> if fst (unfoldTApp scrutType) == TCon () type_ + then f bs + else pure bs + e -> pure e + +progCxt :: Prog -> Cxt +progCxt p = buildTypingContext (allTypes p) (allDefs p) (progSmartHoles p) + +-- | Run a computation in some context whose errors can be promoted to `ProgError`. +liftError :: MonadEditApp m => (e -> ProgError) -> ExceptT e m b -> m b +liftError f = runExceptT >=> either (throwError . f) pure + +allConNames :: Prog -> [ValConName] +allConNames = + toListOf $ + #progModule + % #moduleTypes + % traversed + % #_TypeDefAST + % #astTypeDefConstructors + % traversed + % #valConName + +defaultTypeDefs :: Map TyConName TypeDef defaultTypeDefs = - map - TypeDefAST - [boolDef, natDef, listDef, maybeDef, pairDef, eitherDef] - <> map - TypeDefPrim - (Map.elems allPrimTypeDefs) + mkTypeDefMap $ + map + TypeDefAST + [boolDef, natDef, listDef, maybeDef, pairDef, eitherDef] + <> map + TypeDefPrim + (Map.elems allPrimTypeDefs) -- | A definition of the Bool type boolDef :: ASTTypeDef diff --git a/primer/src/Primer/Core.hs b/primer/src/Primer/Core.hs index 7f2c44937..b82399bd3 100644 --- a/primer/src/Primer/Core.hs +++ b/primer/src/Primer/Core.hs @@ -36,6 +36,7 @@ module Primer.Core ( GVarName, LocalNameKind (..), LocalName (LocalName, unLocalName), + unsafeMkLocalName, LVarName, TyVarName, Type, @@ -73,6 +74,7 @@ module Primer.Core ( _typeMetaLens, bindName, _bindMeta, + typesInExpr, ) where import Foreword @@ -83,7 +85,21 @@ import Data.Data (Data) import Data.Generics.Product import Data.Generics.Uniplate.Data () import Data.Generics.Uniplate.Zipper (Zipper, hole, replaceHole) -import Optics (AffineFold, Lens, Lens', Traversal, afailing, lens, set, view, (%)) +import Optics ( + AffineFold, + Lens, + Lens', + Traversal, + Traversal', + adjoin, + afailing, + lens, + set, + view, + (%), + _3, + _4, + ) import Primer.JSON import Primer.Name (Name, unsafeMkName) @@ -187,6 +203,9 @@ newtype LocalName (k :: LocalNameKind) = LocalName {unLocalName :: Name} deriving (IsString) via Name deriving (FromJSON, ToJSON, FromJSONKey, ToJSONKey) via Name +unsafeMkLocalName :: Text -> LocalName k +unsafeMkLocalName = LocalName . unsafeMkName + type LVarName = LocalName 'ATmVar type TyVarName = LocalName 'ATyVar @@ -346,6 +365,14 @@ data Type' a deriving (Eq, Show, Data, Generic) deriving (FromJSON, ToJSON) via VJSON (Type' a) +-- | Note that this does not recurse in to sub-expressions or sub-types. +typesInExpr :: Traversal' (Expr' a b) (Type' b) +typesInExpr = + #_Ann % _3 + `adjoin` #_APP % _3 + `adjoin` #_LetType % _3 + `adjoin` #_Letrec % _4 + -- | A traversal over the metadata of a type _typeMeta :: Traversal (Type' a) (Type' b) a b _typeMeta = param @0 diff --git a/primer/src/Primer/Core/Transform.hs b/primer/src/Primer/Core/Transform.hs index 9c0551673..d2a994fcb 100644 --- a/primer/src/Primer/Core/Transform.hs +++ b/primer/src/Primer/Core/Transform.hs @@ -4,17 +4,21 @@ module Primer.Core.Transform ( renameTyVar, renameTyVarExpr, unfoldApp, + foldApp, unfoldAPP, + unfoldTApp, unfoldFun, removeAnn, ) where import Foreword +import Control.Monad.Fresh (MonadFresh) import Data.Data (Data) import Data.Generics.Uniplate.Data (descendM) import qualified Data.List.NonEmpty as NE -import Primer.Core (CaseBranch' (..), Expr' (..), LVarName, LocalName (unLocalName), TmVarRef (..), TyVarName, Type' (..), bindName, varRefName) +import Primer.Core (CaseBranch' (..), Expr, Expr' (..), ID, LVarName, LocalName (unLocalName), TmVarRef (..), TyVarName, Type' (..), bindName, varRefName) +import Primer.Core.DSL (meta) -- AST transformations. -- This module contains global transformations on expressions and types, in @@ -104,6 +108,12 @@ unfoldApp = second reverse . go go (App _ f x) = let (g, args) = go f in (g, x : args) go e = (e, []) +-- | Fold an application head and a list of arguments in to a single expression. +foldApp :: (Foldable t, MonadFresh ID m) => Expr -> t Expr -> m Expr +foldApp = foldM $ \a b -> do + m <- meta + pure $ App m a b + -- | Unfold a nested term-level type application into the application head and a list of arguments. unfoldAPP :: Expr' a b -> (Expr' a b, [Type' b]) unfoldAPP = second reverse . go @@ -111,6 +121,13 @@ unfoldAPP = second reverse . go go (APP _ f x) = let (g, args) = go f in (g, x : args) go e = (e, []) +-- | Unfold a nested type-level application into the application head and a list of arguments. +unfoldTApp :: Type' a -> (Type' a, [Type' a]) +unfoldTApp = second reverse . go + where + go (TApp _ f x) = let (g, args) = go f in (g, x : args) + go e = (e, []) + -- | Split a function type into an array of argument types and the result type. -- Takes two arguments: the lhs and rhs of the topmost function node. unfoldFun :: Type' a -> Type' a -> (NonEmpty (Type' a), Type' a) diff --git a/primer/src/Primer/Module.hs b/primer/src/Primer/Module.hs index 23fea2f15..7e67ab45e 100644 --- a/primer/src/Primer/Module.hs +++ b/primer/src/Primer/Module.hs @@ -1,11 +1,11 @@ module Primer.Module (Module (..)) where import Foreword -import Primer.Core (Def, GlobalName, GlobalNameKind (ADefName), TypeDef) +import Primer.Core (Def, GlobalName, GlobalNameKind (ADefName, ATyCon), TypeDef) import Primer.JSON data Module = Module - { moduleTypes :: [TypeDef] + { moduleTypes :: Map (GlobalName 'ATyCon) TypeDef , moduleDefs :: Map (GlobalName 'ADefName) Def -- The current program: a set of definitions indexed by Name } deriving (Eq, Show, Generic) diff --git a/primer/src/Primer/Typecheck.hs b/primer/src/Primer/Typecheck.hs index de98ec1a9..acecaf5dc 100644 --- a/primer/src/Primer/Typecheck.hs +++ b/primer/src/Primer/Typecheck.hs @@ -225,8 +225,8 @@ extendLocalCxtTys x cxt = cxt{localCxt = Map.fromList (bimap unLocalName K <$> x extendGlobalCxt :: [(GVarName, Type)] -> Cxt -> Cxt extendGlobalCxt globals cxt = cxt{globalCxt = Map.fromList globals <> globalCxt cxt} -extendTypeDefCxt :: [TypeDef] -> Cxt -> Cxt -extendTypeDefCxt typedefs cxt = cxt{typeDefs = mkTypeDefMap typedefs <> typeDefs cxt} +extendTypeDefCxt :: [(TyConName, TypeDef)] -> Cxt -> Cxt +extendTypeDefCxt typedefs cxt = cxt{typeDefs = Map.fromList typedefs <> typeDefs cxt} localTmVars :: Cxt -> Map LVarName Type localTmVars = M.mapKeys LocalName . M.mapMaybe (\case T t -> Just t; K _ -> Nothing) . localCxt @@ -248,10 +248,10 @@ initialCxt sh = } -- | Construct an initial typing context, with all given definitions in scope as global variables. -buildTypingContext :: [TypeDef] -> Map GVarName Def -> SmartHoles -> Cxt +buildTypingContext :: Map TyConName TypeDef -> Map GVarName Def -> SmartHoles -> Cxt buildTypingContext tydefs defs sh = let globals = Map.elems $ fmap (\def -> (defName def, forgetTypeIDs (defType def))) defs - in extendTypeDefCxt tydefs $ extendGlobalCxt globals $ initialCxt sh + in extendTypeDefCxt (Map.toList tydefs) $ extendGlobalCxt globals $ initialCxt sh -- | Create a mapping of name to typedef for fast lookup. -- Ensures that @typeDefName (mkTypeDefMap ! n) == n@ @@ -319,13 +319,13 @@ checkTypeDefsMap :: m () checkTypeDefsMap tds = if and $ M.mapWithKey (\n td -> n == typeDefName td) tds - then checkTypeDefs $ M.elems tds + then checkTypeDefs tds else throwError' $ InternalError "Inconsistent names in a Map TyConName TypeDef" -- | Check all type definitions, as one recursive group, in some monadic environment checkTypeDefs :: TypeM e m => - [TypeDef] -> + Map TyConName TypeDef -> m () checkTypeDefs tds = do existingTypes <- asks $ Map.elems . typeDefs @@ -333,18 +333,18 @@ checkTypeDefs tds = do -- errors here are "internal errors" and should never be seen. -- (This is not quite true, see -- https://github.com/hackworthltd/primer/issues/3) - assert (distinct $ map typeDefName $ existingTypes <> tds) "Duplicate-ly-named TypeDefs" + assert (distinct $ map typeDefName $ existingTypes <> Map.elems tds) "Duplicate-ly-named TypeDefs" -- Note that constructors are synthesisable, so their names must be globally -- unique. We need to be able to work out the type of @TCon "C"@ without any -- extra information. - let atds = mapMaybe typeDefAST tds + let atds = mapMaybe typeDefAST $ Map.elems tds let allAtds = mapMaybe typeDefAST existingTypes <> atds assert (distinct $ concatMap (map valConName . astTypeDefConstructors) allAtds) "Duplicate-ly-named constructor (perhaps in different typedefs)" -- Note that these checks only apply to non-primitives: -- duplicate type names are checked elsewhere, kinds are correct by construction, and there are no constructors. - local (extendTypeDefCxt tds) $ mapM_ checkTypeDef atds + local (extendTypeDefCxt $ Map.toList tds) $ mapM_ checkTypeDef atds where -- In the core, we have many different namespaces, so the only name-clash -- checking we must do is @@ -405,19 +405,19 @@ checkEverything :: checkEverything sh CheckEverything{trusted, toCheck} = let cxt = buildTypingContext - (concatMap moduleTypes trusted) + (foldMap moduleTypes trusted) (foldMap moduleDefs trusted) sh in flip runReaderT cxt $ do -- Check that the definition map has the right keys for_ toCheck $ \m -> flip Map.traverseWithKey (moduleDefs m) $ \n d -> unless (n == defName d) $ throwError' $ InternalError "Inconsistant names in moduleDefs map" - checkTypeDefs $ concatMap moduleTypes toCheck + checkTypeDefs $ foldMap moduleTypes toCheck let newTypes = foldMap moduleTypes toCheck newDefs = foldMap (\d -> [(defName d, forgetTypeIDs $ defType d)]) $ foldMap moduleDefs toCheck - local (extendGlobalCxt newDefs . extendTypeDefCxt newTypes) $ + local (extendGlobalCxt newDefs . extendTypeDefCxt (Map.toList newTypes)) $ traverseOf (traversed % #moduleDefs % traversed) checkDef toCheck -- | Typecheck a definition. diff --git a/primer/test/Tests/Action/Prog.hs b/primer/test/Tests/Action/Prog.hs index 76b47d3cf..87975204a 100644 --- a/primer/test/Tests/Action/Prog.hs +++ b/primer/test/Tests/Action/Prog.hs @@ -54,12 +54,14 @@ import Primer.Core ( Kind (KType), Meta (..), TmVarRef (..), + TyConName, Type' (..), TypeDef (..), ValCon (..), defAST, defName, getID, + typeDefAST, _exprMeta, _exprTypeMeta, _id, @@ -75,6 +77,8 @@ import Primer.Core.DSL ( con, create, emptyHole, + gvar, + hole, lAM, lam, lvar, @@ -85,8 +89,10 @@ import Primer.Core.DSL ( tfun, tvar, ) +import Primer.Core.Utils (forgetIDs) import Primer.Module (Module (moduleDefs, moduleTypes)) import Primer.Name +import Primer.Typecheck (mkTypeDefMap) import Test.Tasty.HUnit (Assertion, assertBool, assertFailure, (@=?), (@?=)) import TestM (TestM, evalTestM) import TestUtils (withPrimDefs) @@ -232,7 +238,7 @@ unit_create_typedef = in progActionTest defaultEmptyProg [AddTypeDef lst, AddTypeDef tree] $ expectSuccess $ \_ prog' -> do - case moduleTypes $ progModule prog' of + case Map.elems $ moduleTypes $ progModule prog' of [lst', tree'] -> do TypeDefAST lst @=? lst' TypeDefAST tree @=? tree' @@ -370,7 +376,7 @@ unit_create_typedef_8 = , astTypeDefNameHints = [] } in progActionTest defaultEmptyProg [AddTypeDef td] $ - expectSuccess $ \_ prog' -> moduleTypes (progModule prog') @?= [TypeDefAST td] + expectSuccess $ \_ prog' -> Map.elems (moduleTypes (progModule prog')) @?= [TypeDefAST td] -- Allow clash between type name and constructor name across types unit_create_typedef_9 :: Assertion @@ -390,7 +396,7 @@ unit_create_typedef_9 = , astTypeDefNameHints = [] } in progActionTest defaultEmptyProg [AddTypeDef td1, AddTypeDef td2] $ - expectSuccess $ \_ prog' -> moduleTypes (progModule prog') @?= [TypeDefAST td1, TypeDefAST td2] + expectSuccess $ \_ prog' -> Map.elems (moduleTypes (progModule prog')) @?= [TypeDefAST td2, TypeDefAST td1] unit_construct_arrow_in_sig :: Assertion unit_construct_arrow_in_sig = @@ -708,6 +714,318 @@ unit_rename_def_capture = progActionTest defaultEmptyProg [MoveToDef "other", BodyAction [ConstructLam $ Just "foo"], RenameDef "main" "foo"] $ expectError (@?= ActionError NameCapture) +unit_RenameType :: Assertion +unit_RenameType = + progActionTest + ( defaultProgEditableTypeDefs $ + sequence + [ do + x <- emptyHole `ann` (tcon "T" `tapp` tcon "Bool") + ASTDef "def" x <$> tEmptyHole + ] + ) + [RenameType "T" "T'"] + $ expectSuccess $ \_ prog' -> do + td <- findTypeDef "T'" prog' + astTypeDefName td @?= "T'" + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + emptyHole `ann` (tcon "T'" `tapp` tcon "Bool") + ) + +unit_RenameType_clash :: Assertion +unit_RenameType_clash = + progActionTest + (defaultProgEditableTypeDefs $ pure []) + [RenameType "T" "Int"] + $ expectError (@?= TypeDefAlreadyExists "Int") + +unit_RenameCon :: Assertion +unit_RenameCon = + progActionTest + ( defaultProgEditableTypeDefs $ + sequence + [ do + x <- + hole + ( hole + (con "A") + ) + ASTDef "def" x <$> tEmptyHole + ] + ) + [RenameCon "T" "A" "A'"] + $ expectSuccess $ \_ prog' -> do + td <- findTypeDef "T" prog' + astTypeDefConstructors td + @?= [ ValCon "A'" [TCon () "Bool", TCon () "Bool", TCon () "Bool"] + , ValCon "B" [TVar () "b"] + ] + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + hole + ( hole + (con "A'") + ) + ) + +unit_RenameCon_clash :: Assertion +unit_RenameCon_clash = + progActionTest + ( defaultProgEditableTypeDefs $ + sequence + [ do + x <- + hole + ( hole + (con "A") + ) + ASTDef "def" x <$> tEmptyHole + ] + ) + [RenameCon "T" "A" "True"] + $ expectError (@?= ConAlreadyExists "True") + +unit_RenameTypeParam :: Assertion +unit_RenameTypeParam = + progActionTest + (defaultProgEditableTypeDefs $ pure []) + [RenameTypeParam "T" "b" "b'"] + $ expectSuccess $ \_ prog' -> do + td <- findTypeDef "T" prog' + astTypeDefParameters td @?= [("a", KType), ("b'", KType)] + astTypeDefConstructors td + @?= [ ValCon "A" [TCon () "Bool", TCon () "Bool", TCon () "Bool"] + , ValCon "B" [TVar () "b'"] + ] + +unit_RenameTypeParam_clash :: Assertion +unit_RenameTypeParam_clash = + progActionTest + (defaultProgEditableTypeDefs $ pure []) + [RenameTypeParam "T" "a" "b"] + $ expectError (@?= ParamAlreadyExists "b") + +unit_AddCon :: Assertion +unit_AddCon = + progActionTest + ( defaultProgEditableTypeDefs $ + sequence + [ do + x <- + case_ + (emptyHole `ann` (tcon "T" `tapp` tcon "Bool" `tapp` tcon "Int")) + [ branch "A" [] emptyHole + , branch "B" [] emptyHole + ] + ASTDef "def" x <$> tEmptyHole + ] + ) + [AddCon "T" 1 "C"] + $ expectSuccess $ \_ prog' -> do + td <- findTypeDef "T" prog' + astTypeDefConstructors td + @?= [ ValCon "A" [TCon () "Bool", TCon () "Bool", TCon () "Bool"] + , ValCon "C" [] + , ValCon "B" [TVar () "b"] + ] + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + case_ + (emptyHole `ann` (tcon "T" `tapp` tcon "Bool" `tapp` tcon "Int")) + [ branch "A" [] emptyHole + , branch "C" [] emptyHole + , branch "B" [] emptyHole + ] + ) + +unit_SetConFieldType :: Assertion +unit_SetConFieldType = + progActionTest + ( defaultProgEditableTypeDefs . sequence . pure $ do + x <- con "A" `app` lvar "x" `app` (gvar "y" `ann` tcon "Bool") + ASTDef "def" x <$> tEmptyHole + ) + [SetConFieldType "T" "A" 1 $ TCon () "Int"] + $ expectSuccess $ \_ prog' -> do + td <- findTypeDef "T" prog' + astTypeDefConstructors td + @?= [ ValCon "A" [TCon () "Bool", TCon () "Int", TCon () "Bool"] + , ValCon "B" [TVar () "b"] + ] + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + con "A" `app` lvar "x" `app` hole (gvar "y" `ann` tcon "Bool") + ) + +unit_SetConFieldType_partial_app :: Assertion +unit_SetConFieldType_partial_app = + progActionTest + ( defaultProgEditableTypeDefs $ do + x <- con "A" `app` lvar "x" + sequence + [ ASTDef "def" x <$> tcon "T" + ] + ) + [SetConFieldType "T" "A" 1 $ TCon () "Int"] + $ expectSuccess $ \_ prog' -> do + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + hole $ + con "A" `app` lvar "x" + ) + +unit_SetConFieldType_case :: Assertion +unit_SetConFieldType_case = + progActionTest + ( defaultProgEditableTypeDefs $ do + x <- + case_ + (emptyHole `ann` (tcon "T" `tapp` tEmptyHole `tapp` tEmptyHole)) + [ branch + "A" + [("x", Nothing), ("y", Nothing), ("z", Nothing)] + (lvar "y") + , branch "B" [] emptyHole + ] + sequence + [ ASTDef "def" x <$> tcon "Bool" + ] + ) + [SetConFieldType "T" "A" 1 $ TCon () "Int"] + $ expectSuccess $ \_ prog' -> do + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + case_ + (emptyHole `ann` (tcon "T" `tapp` tEmptyHole `tapp` tEmptyHole)) + [ branch + "A" + [("x", Nothing), ("y", Nothing), ("z", Nothing)] + (hole $ lvar "y") + , branch "B" [] emptyHole + ] + ) + +unit_SetConFieldType_shadow :: Assertion +unit_SetConFieldType_shadow = + progActionTest + ( defaultProgEditableTypeDefs $ do + x <- + case_ + (emptyHole `ann` (tcon "T" `tapp` tEmptyHole `tapp` tEmptyHole)) + [ branch + "A" + [("x", Nothing), ("y", Nothing), ("z", Nothing)] + (lam "y" (lvar "y") `app` lvar "y") + , branch "B" [] emptyHole + ] + sequence + [ ASTDef "def" x <$> tcon "Bool" + ] + ) + [SetConFieldType "T" "A" 1 $ TCon () "Int"] + $ expectSuccess $ \_ prog' -> do + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + case_ + (emptyHole `ann` (tcon "T" `tapp` tEmptyHole `tapp` tEmptyHole)) + [ branch + "A" + [("x", Nothing), ("y", Nothing), ("z", Nothing)] + -- only the free `y` should be put in to a hole + (lam "y" (lvar "y") `app` hole (lvar "y")) + , branch "B" [] emptyHole + ] + ) + +unit_AddConField :: Assertion +unit_AddConField = + progActionTest + ( defaultProgEditableTypeDefs $ do + x <- con "A" `app` con "True" + sequence + [ ASTDef "def" x <$> tEmptyHole + ] + ) + [AddConField "T" "A" 1 $ TCon () "Int"] + $ expectSuccess $ \_ prog' -> do + td <- findTypeDef "T" prog' + astTypeDefConstructors td + @?= [ ValCon "A" [TCon () "Bool", TCon () "Int", TCon () "Bool", TCon () "Bool"] + , ValCon "B" [TVar () "b"] + ] + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + con "A" `app` con "True" `app` emptyHole + ) + +unit_AddConField_partial_app :: Assertion +unit_AddConField_partial_app = + progActionTest + ( defaultProgEditableTypeDefs $ do + x <- con "A" `app` con "True" + sequence + [ ASTDef "def" x <$> tEmptyHole + ] + ) + [AddConField "T" "A" 2 $ TCon () "Int"] + $ expectSuccess $ \_ prog' -> do + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + hole $ con "A" `app` con "True" + ) + +unit_AddConField_case :: Assertion +unit_AddConField_case = + progActionTest + ( defaultProgEditableTypeDefs $ do + x <- + case_ + (emptyHole `ann` (tcon "T" `tapp` tEmptyHole `tapp` tEmptyHole)) + [ branch + "A" + [("x", Nothing), ("y", Nothing), ("z", Nothing)] + (lvar "y") + , branch "B" [] emptyHole + ] + sequence + [ ASTDef "def" x <$> tEmptyHole + ] + ) + [AddConField "T" "A" 2 $ TCon () "Int"] + $ expectSuccess $ \_ prog' -> do + def <- findDef "def" prog' + forgetIDs (astDefExpr def) + @?= forgetIDs + ( fst . create $ + case_ + (emptyHole `ann` (tcon "T" `tapp` tEmptyHole `tapp` tEmptyHole)) + [ branch + "A" + [("x", Nothing), ("y", Nothing), ("a19", Nothing), ("z", Nothing)] + (lvar "y") + , branch "B" [] emptyHole + ] + ) + -- * Utilities findGlobalByName :: Prog -> GVarName -> Maybe Def @@ -752,6 +1070,31 @@ defaultFullProg = do . over (#progModule % #moduleDefs) ((DefPrim <$> m) <>) $ p +findTypeDef :: TyConName -> Prog -> IO ASTTypeDef +findTypeDef d p = maybe (assertFailure "couldn't find typedef") pure $ (typeDefAST <=< Map.lookup d) $ p ^. (#progModule % #moduleTypes) + +findDef :: GVarName -> Prog -> IO ASTDef +findDef d p = maybe (assertFailure "couldn't find def") pure $ (defAST <=< Map.lookup d) $ p ^. (#progModule % #moduleDefs) + +-- We use the same type definition for all tests related to editing type definitions +defaultProgEditableTypeDefs :: MonadFresh ID f => f [ASTDef] -> f Prog +defaultProgEditableTypeDefs ds = do + p <- defaultEmptyProg + ds' <- ds + let tds = + [ TypeDefAST + ASTTypeDef + { astTypeDefName = "T" + , astTypeDefParameters = [("a", KType), ("b", KType)] + , astTypeDefConstructors = [ValCon "A" (replicate 3 $ TCon () "Bool"), ValCon "B" [TVar () "b"]] + , astTypeDefNameHints = [] + } + ] + pure $ + p + & (#progModule % #moduleTypes) %~ ((mkTypeDefMap tds <> defaultTypeDefs) <>) + & (#progModule % #moduleDefs) %~ (Map.fromList ((\d -> (astDefName d, DefAST d)) <$> ds') <>) + unit_good_defaultFullProg :: Assertion unit_good_defaultFullProg = checkProgWellFormed defaultFullProg diff --git a/primer/test/Tests/Eval.hs b/primer/test/Tests/Eval.hs index 3fc66ec07..1fd672d32 100644 --- a/primer/test/Tests/Eval.hs +++ b/primer/test/Tests/Eval.hs @@ -6,7 +6,7 @@ import Foreword import qualified Data.Map.Strict as Map import qualified Data.Set as Set -import Optics (over, (^.)) +import Optics ((^.)) import Primer.App ( App (appIdCounter), EvalReq (EvalReq, evalReqExpr, evalReqRedex), @@ -21,18 +21,14 @@ import Primer.Core ( ASTDef (..), Def (..), Expr, - Expr', ID (ID), Type, - Type', TypeDef (TypeDefAST), getID, - _exprMeta, - _exprTypeMeta, _id, - _typeMeta, ) import Primer.Core.DSL +import Primer.Core.Utils (forgetIDs, forgetTypeIDs) import Primer.Eval ( ApplyPrimFunDetail (..), BetaReductionDetail (..), @@ -51,6 +47,7 @@ import Primer.Eval ( tryReduceType, ) import Primer.Module (Module (Module, moduleDefs, moduleTypes)) +import Primer.Typecheck (mkTypeDefMap) import Primer.Zipper (target) import Test.Tasty.HUnit (Assertion, assertBool, assertFailure, (@?=)) import TestM (evalTestM) @@ -847,7 +844,7 @@ unit_eval_modules_scrutinize_imported_type = where m = Module - { moduleTypes = [TypeDefAST boolDef] + { moduleTypes = mkTypeDefMap [TypeDefAST boolDef] , moduleDefs = mempty } @@ -856,16 +853,8 @@ unit_eval_modules_scrutinize_imported_type = -- | Like '@?=' but specifically for expressions. -- Ignores IDs and metadata. (~=) :: Expr -> Expr -> Assertion -x ~= y = clearMeta x @?= clearMeta y - where - -- Clear all metadata in the given expression - clearMeta :: Expr -> Expr' () () - clearMeta = over _exprMeta (const ()) . over _exprTypeMeta (const ()) +x ~= y = forgetIDs x @?= forgetIDs y -- | Like '~=' but for types. (~~=) :: Type -> Type -> Assertion -x ~~= y = clearMeta x @?= clearMeta y - where - -- Clear all metadata in the given type - clearMeta :: Type -> Type' () - clearMeta = over _typeMeta (const ()) +x ~~= y = forgetTypeIDs x @?= forgetTypeIDs y diff --git a/primer/test/Tests/EvalFull.hs b/primer/test/Tests/EvalFull.hs index 22198d039..c5c9aa8d1 100644 --- a/primer/test/Tests/EvalFull.hs +++ b/primer/test/Tests/EvalFull.hs @@ -166,10 +166,10 @@ unit_8 = expect <- mkList (tcon "Bool") (take n $ cycle [con "True", con "False"]) `ann` (tcon "List" `tapp` tcon "Bool") pure (globs, expr, expect) in do - case evalFullTest maxID (mkTypeDefMap defaultTypeDefs) (M.fromList globals) 500 Syn e of + case evalFullTest maxID defaultTypeDefs (M.fromList globals) 500 Syn e of Left (TimedOut _) -> pure () x -> assertFailure $ show x - let s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) (M.fromList globals) 1000 Syn e + let s = evalFullTest maxID defaultTypeDefs (M.fromList globals) 1000 Syn e distinctIDs s s <~==> Right expected @@ -203,10 +203,10 @@ unit_9 = expect <- mkList (tcon "Bool") (take n $ cycle [con "True", con "False"]) `ann` (tcon "List" `tapp` tcon "Bool") pure (globs, expr, expect) in do - case evalFullTest maxID (mkTypeDefMap defaultTypeDefs) (M.fromList globals) 500 Syn e of + case evalFullTest maxID defaultTypeDefs (M.fromList globals) 500 Syn e of Left (TimedOut _) -> pure () x -> assertFailure $ show x - let s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) (M.fromList globals) 1000 Syn e + let s = evalFullTest maxID defaultTypeDefs (M.fromList globals) 1000 Syn e distinctIDs s s <~==> Right expected @@ -232,8 +232,8 @@ unit_10 = expect <- con "True" pure (annCase, noannCase, expect) in do - let s' = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) mempty 2 Syn s - t' = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) mempty 2 Syn t + let s' = evalFullTest maxID defaultTypeDefs mempty 2 Syn s + t' = evalFullTest maxID defaultTypeDefs mempty 2 Syn t distinctIDs s' s' <~==> Right expected distinctIDs t' @@ -264,10 +264,10 @@ unit_11 = `ann` (tcon "Pair" `tapp` tcon "Bool" `tapp` tcon "Nat") pure (globs, expr, expect) in do - case evalFullTest maxID (mkTypeDefMap defaultTypeDefs) (M.fromList globals) 10 Syn e of + case evalFullTest maxID defaultTypeDefs (M.fromList globals) 10 Syn e of Left (TimedOut _) -> pure () x -> assertFailure $ show x - let s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) (M.fromList globals) 20 Syn e + let s = evalFullTest maxID defaultTypeDefs (M.fromList globals) 20 Syn e distinctIDs s s <~==> Right expected @@ -289,7 +289,7 @@ unit_12 = expect <- con "True" `ann` tcon "Bool" pure (expr, expect) in do - let s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) mempty 15 Syn e + let s = evalFullTest maxID defaultTypeDefs mempty 15 Syn e distinctIDs s s <~==> Right expected @@ -300,7 +300,7 @@ unit_13 = expect <- (con "C" `app` con "Zero" `app` con "True" `app` con "Zero") `ann` tcon "Bool" pure (expr, expect) in do - let s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) mempty 15 Syn e + let s = evalFullTest maxID defaultTypeDefs mempty 15 Syn e distinctIDs s s <~==> Right expected @@ -311,7 +311,7 @@ unit_14 = expect <- con "Zero" `ann` tcon "Nat" pure (expr, expect) in do - let s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) mempty 15 Syn e + let s = evalFullTest maxID defaultTypeDefs mempty 15 Syn e distinctIDs s s <~==> Right expected @@ -338,19 +338,19 @@ unit_15 = e5 <- lam y' $ c "y" y' pure (e0, [e0, e1, e2, e3, e4, e5], e5) in do - let si = map (\i -> evalFullTest maxID (mkTypeDefMap defaultTypeDefs) mempty i Syn expr) [0 .. fromIntegral $ length steps - 1] + let si = map (\i -> evalFullTest maxID defaultTypeDefs mempty i Syn expr) [0 .. fromIntegral $ length steps - 1] f s e = do distinctIDs s s <~==> Left (TimedOut e) zipWithM_ f si steps - let s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) mempty (fromIntegral $ length steps) Syn expr + let s = evalFullTest maxID defaultTypeDefs mempty (fromIntegral $ length steps) Syn expr distinctIDs s s <~==> Right expected unit_hole_ann_case :: Assertion unit_hole_ann_case = let (tm, maxID) = create $ hole $ ann (case_ emptyHole []) (tcon "Bool") - in evalFullTest maxID (mkTypeDefMap defaultTypeDefs) mempty 1 Chk tm @?= Right tm + in evalFullTest maxID defaultTypeDefs mempty 1 Chk tm @?= Right tm -- TODO: examples with holes @@ -509,7 +509,7 @@ hprop_prim_hex_nat = withTests 20 . property $ do <*> con "Nothing" `aPP` tcon "Char" <*> pure (DefPrim <$> globals) - s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) gs 7 Syn e + s = evalFullTest maxID defaultTypeDefs gs 7 Syn e set _ids' 0 s === set _ids' 0 (Right r) unit_prim_char_eq_1 :: Assertion @@ -855,7 +855,7 @@ unit_prim_ann = `app` (char 'a' `ann` tcon "Char") <*> char 'A' <*> pure (DefPrim <$> globals) - s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) gs 2 Syn e + s = evalFullTest maxID defaultTypeDefs gs 2 Syn e in do distinctIDs s s <~==> Right r @@ -884,7 +884,7 @@ unit_prim_partial_map = ] `ann` (tcon "List" `tapp` tcon "Char") <*> pure (M.singleton (defName map_) map_ <> (DefPrim <$> globals)) - s = evalFullTest maxID (mkTypeDefMap defaultTypeDefs) gs 65 Syn e + s = evalFullTest maxID defaultTypeDefs gs 65 Syn e in do distinctIDs s s <~==> Right r @@ -945,7 +945,7 @@ unit_eval_full_modules_scrutinize_imported_type = where m = Module - { moduleTypes = [TypeDefAST boolDef] + { moduleTypes = mkTypeDefMap [TypeDefAST boolDef] , moduleDefs = mempty } diff --git a/primer/test/Tests/Primitives.hs b/primer/test/Tests/Primitives.hs index a706adc27..4d72378fc 100644 --- a/primer/test/Tests/Primitives.hs +++ b/primer/test/Tests/Primitives.hs @@ -29,6 +29,7 @@ import Primer.Typecheck ( buildTypingContext, checkKind, checkValidContext, + mkTypeDefMap, synth, ) @@ -72,5 +73,5 @@ unit_prim_con_scope_ast = do , astTypeDefNameHints = mempty } - cxt = buildTypingContext [charASTDef] mempty NoSmartHoles + cxt = buildTypingContext (mkTypeDefMap [charASTDef]) mempty NoSmartHoles test = runTypecheckTestMFromIn 0 cxt diff --git a/primer/test/Tests/Serialization.hs b/primer/test/Tests/Serialization.hs index 060001176..2a71f8145 100644 --- a/primer/test/Tests/Serialization.hs +++ b/primer/test/Tests/Serialization.hs @@ -44,7 +44,7 @@ import Primer.Core ( ) import Primer.Module (Module (Module, moduleDefs, moduleTypes)) import Primer.Name (unsafeMkName) -import Primer.Typecheck (SmartHoles (SmartHoles)) +import Primer.Typecheck (SmartHoles (SmartHoles), mkTypeDefMap) import System.FilePath (takeBaseName) import Test.Tasty import Test.Tasty.Golden @@ -112,7 +112,7 @@ fixtures = { progImports = mempty , progModule = Module - { moduleTypes = [typeDef] + { moduleTypes = mkTypeDefMap [typeDef] , moduleDefs = Map.singleton (astDefName def) (DefAST def) } , progSelection = Just selection diff --git a/primer/test/Tests/Typecheck.hs b/primer/test/Tests/Typecheck.hs index 577572b88..c5c1a6807 100644 --- a/primer/test/Tests/Typecheck.hs +++ b/primer/test/Tests/Typecheck.hs @@ -42,6 +42,7 @@ import Primer.Core ( Meta (..), PrimDef (PrimDef, primDefName, primDefType), TmVarRef (LocalVarRef), + TyConName, Type, Type' (TApp, TCon, TForall, TFun, TVar), TypeCache (..), @@ -70,6 +71,7 @@ import Primer.Typecheck ( checkEverything, decomposeTAppCon, mkTAppCon, + mkTypeDefMap, synth, synthKind, ) @@ -512,7 +514,7 @@ unit_good_maybeT = case runTypecheckTestM NoSmartHoles $ NoSmartHoles CheckEverything { trusted = [progModule newProg] - , toCheck = [Module [TypeDefAST maybeTDef] mempty] + , toCheck = [Module (mkTypeDefMap [TypeDefAST maybeTDef]) mempty] } of Left err -> assertFailure $ show err Right _ -> pure () @@ -525,7 +527,7 @@ unit_bad_prim_map = case runTypecheckTestM NoSmartHoles $ do NoSmartHoles CheckEverything { trusted = [progModule newProg] - , toCheck = [Module [] $ Map.singleton "foo" $ DefPrim foo] + , toCheck = [Module mempty $ Map.singleton "foo" $ DefPrim foo] } of Left err -> err @?= InternalError "Inconsistant names in moduleDefs map" Right _ -> assertFailure "Expected failure but succeeded" @@ -538,7 +540,7 @@ unit_bad_prim_type = case runTypecheckTestM NoSmartHoles $ do NoSmartHoles CheckEverything { trusted = [progModule newProg] - , toCheck = [Module [] $ Map.singleton "foo" $ DefPrim foo] + , toCheck = [Module mempty $ Map.singleton "foo" $ DefPrim foo] } of Left err -> err @?= UnknownTypeConstructor "NonExistant" Right _ -> assertFailure "Expected failure but succeeded" @@ -634,8 +636,8 @@ runTypecheckTestMWithPrims sh = where (defs, n) = create $ withPrimDefs $ \m -> pure $ DefPrim <$> m -testingTypeDefs :: [TypeDef] -testingTypeDefs = TypeDefAST maybeTDef : defaultTypeDefs +testingTypeDefs :: Map TyConName TypeDef +testingTypeDefs = mkTypeDefMap [TypeDefAST maybeTDef] <> defaultTypeDefs maybeTDef :: ASTTypeDef maybeTDef = diff --git a/primer/test/outputs/serialization/edit_response_2.json b/primer/test/outputs/serialization/edit_response_2.json index b25baf3ea..9936c2986 100644 --- a/primer/test/outputs/serialization/edit_response_2.json +++ b/primer/test/outputs/serialization/edit_response_2.json @@ -51,8 +51,8 @@ "tag": "DefAST" } }, - "moduleTypes": [ - { + "moduleTypes": { + "T": { "contents": { "astTypeDefConstructors": [ { @@ -115,7 +115,7 @@ }, "tag": "TypeDefAST" } - ] + } }, "progSelection": { "selectedDef": "main", diff --git a/primer/test/outputs/serialization/prog.json b/primer/test/outputs/serialization/prog.json index 146091dab..6ee4f26c5 100644 --- a/primer/test/outputs/serialization/prog.json +++ b/primer/test/outputs/serialization/prog.json @@ -50,8 +50,8 @@ "tag": "DefAST" } }, - "moduleTypes": [ - { + "moduleTypes": { + "T": { "contents": { "astTypeDefConstructors": [ { @@ -114,7 +114,7 @@ }, "tag": "TypeDefAST" } - ] + } }, "progSelection": { "selectedDef": "main",