diff --git a/packages/reflex-base/src/reflex_base/compiler/templates.py b/packages/reflex-base/src/reflex_base/compiler/templates.py index b4056227987..3c525e99f4e 100644 --- a/packages/reflex-base/src/reflex_base/compiler/templates.py +++ b/packages/reflex-base/src/reflex_base/compiler/templates.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from reflex.compiler.utils import _ImportDict - from reflex_base.components.component import Component, StatefulComponent + from reflex_base.components.component import Component def _sort_hooks( @@ -348,7 +348,12 @@ def context_template( export const initialState = {"{}" if not initial_state else json_dumps(initial_state)} export const defaultColorMode = {default_color_mode} -export const ColorModeContext = createContext(null); +export const ColorModeContext = createContext({{ + colorMode: defaultColorMode, + resolvedColorMode: defaultColorMode === "dark" ? "dark" : "light", + toggleColorMode: () => {{}}, + setColorMode: () => {{}}, +}}); export const UploadFilesContext = createContext(null); export const DispatchContext = createContext(null); export const StateContexts = {{{state_contexts_str}}}; @@ -417,7 +422,7 @@ def context_template( }}""" -def component_template(component: Component | StatefulComponent): +def component_template(component: Component): """Template to render a component tag. Args: @@ -618,24 +623,23 @@ def vite_config_template( }}));""" -def stateful_component_template( - tag_name: str, memo_trigger_hooks: list[str], component: Component, export: bool -): - """Template for stateful component. +def dynamic_component_template( + tag_name: str, component: Component, export: bool +) -> str: + """Template for a dynamic SSR component function declaration. Args: tag_name: The tag name for the component. - memo_trigger_hooks: The memo trigger hooks for the component. component: The component to render. export: Whether to export the component. Returns: - Rendered stateful component code as string. + Rendered dynamic component code as string. """ all_hooks = component._get_all_hooks() return f""" {"export " if export else ""}function {tag_name} () {{ - {_render_hooks(all_hooks, memo_trigger_hooks)} + {_render_hooks(all_hooks)} return ( {_RenderUtils.render(component.render())} ) @@ -643,15 +647,17 @@ def stateful_component_template( """ -def stateful_components_template(imports: list[_ImportDict], memoized_code: str) -> str: - """Template for stateful components. +def dynamic_components_module_template( + imports: list[_ImportDict], memoized_code: str +) -> str: + """Template for a dynamic-SSR components module. Args: imports: List of import statements. - memoized_code: Memoized code for stateful components. + memoized_code: Code for the module body. Returns: - Rendered stateful components code as string. + Rendered module code as string. """ imports_str = "\n".join([_RenderUtils.get_import(imp) for imp in imports]) return f"{imports_str}\n{memoized_code}" @@ -709,6 +715,83 @@ def memo_components_template( {components_code}""" +def memo_single_component_template( + imports: list[_ImportDict], + component: dict[str, Any], + dynamic_imports: Iterable[str], + custom_codes: Iterable[str], +) -> str: + """Template for a single memoized component in its own module. + + Args: + imports: List of import statements for this memo only. + component: The single component definition to render. + dynamic_imports: Dynamic import statements scoped to this memo. + custom_codes: Custom code snippets scoped to this memo. + + Returns: + The rendered standalone memo module code. + """ + imports_str = "\n".join([_RenderUtils.get_import(imp) for imp in imports]) + dynamic_imports_str = "\n".join(dynamic_imports) + custom_code_str = "\n".join(custom_codes) + + component_code = f""" +export const {component["name"]} = memo(({component["signature"]}) => {{ + {_render_hooks(component.get("hooks", {}))} + return( + {_RenderUtils.render(component["render"])} + ) +}}); +""" + + return f""" +{imports_str} + +{dynamic_imports_str} + +{custom_code_str} + +{component_code}""" + + +def memo_single_function_template( + imports: list[_ImportDict], + function: dict[str, Any], +) -> str: + """Template for a single function memo in its own module. + + Args: + imports: List of import statements for this memo only. + function: The single function memo definition. + + Returns: + The rendered standalone function memo module code. + """ + imports_str = "\n".join([_RenderUtils.get_import(imp) for imp in imports]) + return f""" +{imports_str} + +export const {function["name"]} = {function["function"]}; +""" + + +def memo_index_template(reexports: Iterable[tuple[str, str]]) -> str: + """Template for the memo index module that re-exports every memo file. + + Args: + reexports: Iterable of ``(export_name, relative_module_specifier)``. + + Returns: + The rendered index module code. + """ + lines = [ + f'export {{ {export_name} }} from "{specifier}";' + for export_name, specifier in reexports + ] + return "\n".join(lines) + "\n" + + def styles_template(stylesheets: list[str]) -> str: """Template for styles.css. diff --git a/packages/reflex-base/src/reflex_base/components/component.py b/packages/reflex-base/src/reflex_base/components/component.py index 8f47f447c7d..79b562ddc78 100644 --- a/packages/reflex-base/src/reflex_base/components/component.py +++ b/packages/reflex-base/src/reflex_base/components/component.py @@ -3,11 +3,11 @@ from __future__ import annotations import contextlib -import copy import dataclasses import enum import functools import inspect +import operator import typing from abc import ABC, ABCMeta, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence @@ -22,19 +22,10 @@ from reflex_base import constants from reflex_base.breakpoints import Breakpoints -from reflex_base.compiler.templates import stateful_component_template from reflex_base.components.dynamic import load_dynamic_serializer from reflex_base.components.field import BaseField, FieldBasedMeta from reflex_base.components.tags import Tag -from reflex_base.constants import ( - Dirs, - EventTriggers, - Hooks, - Imports, - MemoizationDisposition, - MemoizationMode, - PageNames, -) +from reflex_base.constants import Dirs, EventTriggers, Hooks, Imports, MemoizationMode from reflex_base.constants.compiler import SpecialAttributes from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER from reflex_base.event import ( @@ -321,6 +312,31 @@ def set(self, **kwargs): setattr(self, key, value) return self + def __copy__(self) -> BaseComponent: + """Return a shallow copy suitable for compile-time mutation. + + Bypasses ``copy.copy``'s generic ``__reduce_ex__`` dispatch. Nested + mutable containers (``children``, ``style``, ``event_triggers``) are + shared with the original until the caller explicitly rebinds them. + Render-path caches populated on the original are dropped so the clone + recomputes against its (potentially rebound) fields. + + Returns: + A new instance of the same class with ``__dict__`` shallow-copied. + """ + new = self.__class__.__new__(self.__class__) + new_dict = vars(new) + new_dict.update(vars(self)) + for attr in ( + "_cached_render_result", + "_vars_cache", + "_imports_cache", + "_hooks_internal_cache", + "_get_component_prop_property", + ): + new_dict.pop(attr, None) + return new + def __eq__(self, value: Any) -> bool: """Check if the component is equal to another value. @@ -502,14 +518,65 @@ def _hash_str(value: str) -> str: return md5(f'"{value}"'.encode(), usedforsecurity=False).hexdigest() -def _hash_sequence(value: Sequence) -> str: - return _hash_str(str([_deterministic_hash(v) for v in value])) +def _update_deterministic_hash(hasher: Any, value: object) -> None: + """Feed ``value`` into ``hasher`` using a self-delimiting, type-tagged encoding. + Each branch writes a distinct type tag plus length-prefixed payload, which + keeps the encoding injective without building intermediate strings — the + nested ``str([...])`` approach this replaces was the dominant cost of + ``_deterministic_hash`` (~4x speedup on synthetic, ~2x on real renders). -def _hash_dict(value: dict) -> str: - return _hash_sequence( - sorted([(k, _deterministic_hash(v)) for k, v in value.items()]) - ) + Args: + hasher: A ``hashlib`` hasher (must accept ``.update(bytes)``). + value: The value to fold into the hasher. + + Raises: + TypeError: If the value is not hashable. + """ + if value is None: + hasher.update(b"N") + elif isinstance(value, bool): + hasher.update(b"T" if value else b"F") + elif isinstance(value, (int, float, enum.Enum)): + hasher.update(b"n") + hasher.update(str(value).encode()) + elif isinstance(value, str): + encoded = value.encode() + hasher.update(b"s") + hasher.update(len(encoded).to_bytes(8, "little")) + hasher.update(encoded) + elif isinstance(value, dict): + items = sorted(value.items(), key=operator.itemgetter(0)) + hasher.update(b"d") + hasher.update(len(items).to_bytes(8, "little")) + for k, v in items: + _update_deterministic_hash(hasher, k) + _update_deterministic_hash(hasher, v) + elif isinstance(value, (tuple, list)): + hasher.update(b"l") + hasher.update(len(value).to_bytes(8, "little")) + for item in value: + _update_deterministic_hash(hasher, item) + elif isinstance(value, Var): + hasher.update(b"v") + _update_deterministic_hash(hasher, value._js_expr) + _update_deterministic_hash(hasher, value._get_all_var_data()) + elif dataclasses.is_dataclass(value): + fields = dataclasses.fields(value) + hasher.update(b"D") + hasher.update(len(fields).to_bytes(8, "little")) + for field in fields: + hasher.update(field.name.encode()) + _update_deterministic_hash(hasher, getattr(value, field.name)) + elif isinstance(value, BaseComponent): + hasher.update(b"C") + _update_deterministic_hash(hasher, value.render()) + else: + msg = ( + f"Cannot hash value `{value}` of type `{type(value).__name__}`. " + "Only BaseComponent, Var, VarData, dict, str, tuple, and enum.Enum are supported." + ) + raise TypeError(msg) def _deterministic_hash(value: object) -> str: @@ -524,36 +591,9 @@ def _deterministic_hash(value: object) -> str: Raises: TypeError: If the value is not hashable. """ - if value is None: - # Hash None as a special case. - return "None" - if isinstance(value, (int, float, enum.Enum)): - # Hash numbers and booleans directly. - return str(value) - if isinstance(value, str): - return _hash_str(value) - if isinstance(value, dict): - return _hash_dict(value) - if isinstance(value, (tuple, list)): - # Hash tuples by hashing each element. - return _hash_sequence(value) - if isinstance(value, Var): - return _hash_str( - str((value._js_expr, _deterministic_hash(value._get_all_var_data()))) - ) - if dataclasses.is_dataclass(value): - return _hash_dict({ - k.name: getattr(value, k.name) for k in dataclasses.fields(value) - }) - if isinstance(value, BaseComponent): - # If the value is a component, hash its rendered code. - return _hash_dict(value.render()) - - msg = ( - f"Cannot hash value `{value}` of type `{type(value).__name__}`. " - "Only BaseComponent, Var, VarData, dict, str, tuple, and enum.Enum are supported." - ) - raise TypeError(msg) + hasher = md5(usedforsecurity=False) + _update_deterministic_hash(hasher, value) + return hasher.hexdigest() @dataclasses.dataclass(kw_only=True, frozen=True, slots=True) @@ -1300,7 +1340,7 @@ def _add_style_recursive( # Recursively add style to the children. for child in self.children: - # Skip BaseComponent and StatefulComponent children. + # Skip non-Component children. if not isinstance(child, Component): continue child._add_style_recursive(style, theme) @@ -1327,6 +1367,10 @@ def render(self) -> dict: Returns: The dictionary for template of component. """ + try: + return self._cached_render_result + except AttributeError: + pass tag = self._render() rendered_dict = dict( tag.set( @@ -1334,6 +1378,7 @@ def render(self) -> dict: ) ) self._replace_prop_names(rendered_dict) + self._cached_render_result = rendered_dict return rendered_dict def _replace_prop_names(self, rendered_dict: dict) -> None: @@ -1457,11 +1502,14 @@ def _get_vars( Yields: Each var referenced by the component (props, styles, event handlers). """ + if not include_children and ignore_ids is None: + cached = self.__dict__.get("_vars_cache") + if cached is not None: + yield from cached + return + ignore_ids = ignore_ids or set() - vars: list[Var] | None = getattr(self, "__vars", None) - if vars is not None: - yield from vars - vars = self.__vars = [] + vars: list[Var] = [] # Get Vars associated with event trigger arguments. for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers): vars.extend(event_vars) @@ -1500,17 +1548,19 @@ def _get_vars( if var._get_all_var_data() is not None: vars.append(var) - # Get Vars associated with children. if include_children: + yield from vars for child in self.children: if not isinstance(child, Component) or id(child) in ignore_ids: continue ignore_ids.add(id(child)) - child_vars = child._get_vars( + yield from child._get_vars( include_children=include_children, ignore_ids=ignore_ids ) - vars.extend(child_vars) + return + # Freeze and cache the default-args result. + self._vars_cache = tuple(vars) yield from vars def _event_trigger_values_use_state(self) -> bool: @@ -1555,6 +1605,7 @@ def _iter_parent_classes_names(cls) -> Iterator[str]: yield clz.__name__ @classmethod + @functools.cache def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Component]]: """Iterate through parent classes that define a given method. @@ -1581,7 +1632,7 @@ def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Componen continue seen_methods.add(method_func) clzs.append(clz) - return clzs + return tuple(clzs) def _get_custom_code(self) -> str | None: """Get custom code for the component. @@ -1704,6 +1755,10 @@ def _get_imports(self) -> ParsedImportDict: Returns: The imports needed by the component. """ + cached = self.__dict__.get("_imports_cache") + if cached is not None: + return cached + imports_ = ( {self.library: [self.import_var]} if self.library is not None and self.tag is not None @@ -1731,7 +1786,7 @@ def _get_imports(self) -> ParsedImportDict: imports.parse_imports(item) for item in list_of_import_dict ]) - return imports.merge_parsed_imports( + result = imports.merge_parsed_imports( self._get_dependencies_imports(), self._get_hooks_imports(), imports_, @@ -1739,6 +1794,8 @@ def _get_imports(self) -> ParsedImportDict: *var_imports, *added_import_dicts, ) + self._imports_cache = result + return result def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict: """Get all the libraries and fields that are used by the component and its children. @@ -1840,15 +1897,21 @@ def _get_hooks_internal(self) -> dict[str, VarData | None]: Returns: The internally managed hooks. """ - return { + cached = self.__dict__.get("_hooks_internal_cache") + if cached is not None: + return cached + + result = { + **self._get_events_hooks(), **{ str(hook): VarData(position=Hooks.HookPosition.INTERNAL) for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()] if hook is not None }, **self._get_vars_hooks(), - **self._get_events_hooks(), } + self._hooks_internal_cache = result + return result def _get_added_hooks(self) -> dict[str, VarData | None]: """Get the hooks added via `add_hooks` method. @@ -2005,7 +2068,7 @@ def _get_all_app_wrap_components( # Add the app wrap components for the children. for child in self.children: child_id = id(child) - # Skip BaseComponent and StatefulComponent children. + # Skip non-Component children. if not isinstance(child, Component) or child_id in ignore_ids: continue ignore_ids.add(child_id) @@ -2382,440 +2445,6 @@ def _get_dynamic_imports(self) -> str: ) -class StatefulComponent(BaseComponent): - """A component that depends on state and is rendered outside of the page component. - - If a StatefulComponent is used in multiple pages, it will be rendered to a common file and - imported into each page that uses it. - - A stateful component has a tag name that includes a hash of the code that it renders - to. This tag name refers to the specific component with the specific props that it - was created with. - """ - - # Reference to the original component that was memoized into this component. - component: Component = field( - default_factory=Component, is_javascript_property=False - ) - - references: int = field( - doc="How many times this component is referenced in the app.", - default=0, - is_javascript_property=False, - ) - - rendered_as_shared: bool = field( - doc="Whether the component has already been rendered to a shared file.", - default=False, - is_javascript_property=False, - ) - - memo_trigger_hooks: list[str] = field( - default_factory=list, is_javascript_property=False - ) - - @classmethod - def create(cls, component: Component) -> StatefulComponent | None: - """Create a stateful component from a component. - - Args: - component: The component to memoize. - - Returns: - The stateful component or None if the component should not be memoized. - """ - from reflex_components_core.core.foreach import Foreach - - from reflex_base.registry import RegistrationContext - - if component._memoization_mode.disposition == MemoizationDisposition.NEVER: - # Never memoize this component. - return None - - if component.tag is None: - # Only memoize components with a tag. - return None - - # If _var_data is found in this component, it is a candidate for auto-memoization. - should_memoize = False - - # If the component requests to be memoized, then ignore other checks. - if component._memoization_mode.disposition == MemoizationDisposition.ALWAYS: - should_memoize = True - - if not should_memoize: - # Determine if any Vars have associated data. - for prop_var in component._get_vars(include_children=True): - if prop_var._get_all_var_data(): - should_memoize = True - break - - if not should_memoize: - # Check for special-cases in child components. - for child in component.children: - # Skip BaseComponent and StatefulComponent children. - if not isinstance(child, Component): - continue - # Always consider Foreach something that must be memoized by the parent. - if isinstance(child, Foreach): - should_memoize = True - break - child = cls._child_var(child) - if isinstance(child, Var) and child._get_all_var_data(): - should_memoize = True - break - - if should_memoize or component.event_triggers: - # Render the component to determine tag+hash based on component code. - tag_name = cls._get_tag_name(component) - if tag_name is None: - return None - - # Look up the tag in the cache - ctx = RegistrationContext.get() - stateful_component = ctx.tag_to_stateful_component.get(tag_name) - if stateful_component is None: - memo_trigger_hooks = cls._fix_event_triggers(component) - # Set the stateful component in the cache for the given tag. - stateful_component = ctx.tag_to_stateful_component.setdefault( - tag_name, - cls( - children=component.children, - component=component, - tag=tag_name, - memo_trigger_hooks=memo_trigger_hooks, - ), - ) - # Bump the reference count -- multiple pages referencing the same component - # will result in writing it to a common file. - stateful_component.references += 1 - return stateful_component - - # Return None to indicate this component should not be memoized. - return None - - @staticmethod - def _child_var(child: Component) -> Var | Component: - """Get the Var from a child component. - - This method is used for special cases when the StatefulComponent should actually - wrap the parent component of the child instead of recursing into the children - and memoizing them independently. - - Args: - child: The child component. - - Returns: - The Var from the child component or the child itself (for regular cases). - """ - from reflex_components_core.base.bare import Bare - from reflex_components_core.core.cond import Cond - from reflex_components_core.core.foreach import Foreach - from reflex_components_core.core.match import Match - - if isinstance(child, Bare): - return child.contents - if isinstance(child, Cond): - return child.cond - if isinstance(child, Foreach): - return child.iterable - if isinstance(child, Match): - return child.cond - return child - - @classmethod - def _get_tag_name(cls, component: Component) -> str | None: - """Get the tag based on rendering the given component. - - Args: - component: The component to render. - - Returns: - The tag for the stateful component. - """ - # Get the render dict for the component. - rendered_code = component.render() - if not rendered_code: - # Never memoize non-visual components. - return None - - # Compute the hash based on the rendered code. - code_hash = _hash_str(_deterministic_hash(rendered_code)) - - # Format the tag name including the hash. - return format.format_state_name( - f"{component.tag or 'Comp'}_{code_hash}" - ).capitalize() - - def _render_stateful_code( - self, - export: bool = False, - ) -> str: - if not self.tag: - return "" - # Render the code for this component and hooks. - return stateful_component_template( - tag_name=self.tag, - memo_trigger_hooks=self.memo_trigger_hooks, - component=self.component, - export=export, - ) - - @classmethod - def _fix_event_triggers( - cls, - component: Component, - ) -> list[str]: - """Render the code for a stateful component. - - Args: - component: The component to render. - - Returns: - The memoized event trigger hooks for the component. - """ - # Memoize event triggers useCallback to avoid unnecessary re-renders. - memo_event_triggers = tuple(cls._get_memoized_event_triggers(component).items()) - - # Trigger hooks stored separately to write after the normal hooks (see stateful_component.js.jinja2) - memo_trigger_hooks: list[str] = [] - - if memo_event_triggers: - # Copy the component to avoid mutating the original. - component = copy.copy(component) - - for event_trigger, ( - memo_trigger, - memo_trigger_hook, - ) in memo_event_triggers: - # Replace the event trigger with the memoized version. - memo_trigger_hooks.append(memo_trigger_hook) - component.event_triggers[event_trigger] = memo_trigger - - return memo_trigger_hooks - - @staticmethod - def _get_hook_deps(hook: str) -> list[str]: - """Extract var deps from a hook. - - Args: - hook: The hook line to extract deps from. - - Returns: - A list of var names created by the hook declaration. - """ - # Ensure that the hook is a var declaration. - var_decl = hook.partition("=")[0].strip() - if not any(var_decl.startswith(kw) for kw in ["const ", "let ", "var "]): - return [] - - # Extract the var name from the declaration. - _, _, var_name = var_decl.partition(" ") - var_name = var_name.strip() - - # Break up array and object destructuring if used. - if var_name.startswith(("[", "{")): - return [ - v.strip().replace("...", "") for v in var_name.strip("[]{}").split(",") - ] - return [var_name] - - @staticmethod - def _get_deps_from_event_trigger( - event: EventChain | EventSpec | Var, - ) -> dict[str, None]: - """Get the dependencies accessed by event triggers. - - Args: - event: The event trigger to extract deps from. - - Returns: - The dependencies accessed by the event triggers. - """ - events: list = [event] - deps = {} - - if isinstance(event, EventChain): - events.extend(event.events) - - for ev in events: - if isinstance(ev, EventSpec): - for arg in ev.args: - for a in arg: - var_datas = VarData.merge(a._get_all_var_data()) - if var_datas and var_datas.deps is not None: - deps |= {str(dep): None for dep in var_datas.deps} - return deps - - @classmethod - def _get_memoized_event_triggers( - cls, - component: Component, - ) -> dict[str, tuple[Var, str]]: - """Memoize event handler functions with useCallback to avoid unnecessary re-renders. - - Args: - component: The component with events to memoize. - - Returns: - A dict of event trigger name to a tuple of the memoized event trigger Var and - the hook code that memoizes the event handler. - """ - trigger_memo = {} - for event_trigger, event_args in component._get_vars_from_event_triggers( - component.event_triggers - ): - if event_trigger in { - EventTriggers.ON_MOUNT, - EventTriggers.ON_UNMOUNT, - EventTriggers.ON_SUBMIT, - }: - # Do not memoize lifecycle or submit events. - continue - - # Get the actual EventSpec and render it. - event = component.event_triggers[event_trigger] - rendered_chain = str(LiteralVar.create(event)) - - # Hash the rendered EventChain to get a deterministic function name. - chain_hash = md5(str(rendered_chain).encode("utf-8")).hexdigest() - memo_name = f"{event_trigger}_{chain_hash}" - - # Calculate Var dependencies accessed by the handler for useCallback dep array. - var_deps = ["addEvents", "ReflexEvent"] - - # Get deps from event trigger var data. - var_deps.extend(cls._get_deps_from_event_trigger(event)) - - # Get deps from hooks. - for arg in event_args: - var_data = arg._get_all_var_data() - if var_data is None: - continue - for hook in var_data.hooks: - var_deps.extend(cls._get_hook_deps(hook)) - memo_var_data = VarData.merge( - *[var._get_all_var_data() for var in event_args], - VarData( - imports={"react": [ImportVar(tag="useCallback")]}, - ), - ) - - # Store the memoized function name and hook code for this event trigger. - trigger_memo[event_trigger] = ( - Var(_js_expr=memo_name)._replace( - _var_type=EventChain, merge_var_data=memo_var_data - ), - f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])", - ) - return trigger_memo - - def _get_all_hooks_internal(self) -> dict[str, VarData | None]: - """Get the reflex internal hooks for the component and its children. - - Returns: - The code that should appear just before user-defined hooks. - """ - return {} - - def _get_all_hooks(self) -> dict[str, VarData | None]: - """Get the React hooks for this component. - - Returns: - The code that should appear just before returning the rendered component. - """ - return {} - - def _get_all_imports(self) -> ParsedImportDict: - """Get all the libraries and fields that are used by the component. - - Returns: - The import dict with the required imports. - """ - if self.rendered_as_shared: - return { - f"$/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [ - ImportVar(tag=self.tag) - ] - } - return self.component._get_all_imports() - - def _get_all_dynamic_imports(self) -> set[str]: - """Get dynamic imports for the component. - - Returns: - The dynamic imports. - """ - if self.rendered_as_shared: - return set() - return self.component._get_all_dynamic_imports() - - def _get_all_custom_code(self, export: bool = False) -> dict[str, None]: - """Get custom code for the component. - - Args: - export: Whether to export the component. - - Returns: - The custom code. - """ - if self.rendered_as_shared: - return {} - return self.component._get_all_custom_code() | ({ - self._render_stateful_code(export=export): None - }) - - def _get_all_refs(self) -> dict[str, None]: - """Get the refs for the children of the component. - - Returns: - The refs for the children. - """ - if self.rendered_as_shared: - return {} - return self.component._get_all_refs() - - def render(self) -> dict: - """Define how to render the component in React. - - Returns: - The tag to render. - """ - return dict(Tag(name=self.tag or "")) - - def __str__(self) -> str: - """Represent the component in React. - - Returns: - The code to render the component. - """ - from reflex.compiler.compiler import _compile_component - - return _compile_component(self) - - @classmethod - def compile_from(cls, component: BaseComponent) -> BaseComponent: - """Walk through the component tree and memoize all stateful components. - - Args: - component: The component to memoize. - - Returns: - The memoized component tree. - """ - if isinstance(component, Component): - if component._memoization_mode.recursive: - # Recursively memoize stateful children (default). - component.children = [ - cls.compile_from(child) for child in component.children - ] - # Memoize this component if it depends on state. - stateful_component = cls.create(component) - if stateful_component is not None: - return stateful_component - return component - - class MemoizationLeaf(Component): """A component that does not separately memoize its children. @@ -2823,30 +2452,15 @@ class MemoizationLeaf(Component): components within it, should be a memoization leaf so the compiler does not replace the provided child tags with memoized tags. - During creation, a memoization leaf will mark itself as wanting to be - memoized if any of its children return any hooks. + Whether the leaf is wrapped in a memo definition is decided by the + compiler's snapshot-boundary subtree scan, not by a class-local + disposition override — so leaves and components that explicitly set + ``_memoization_mode = MemoizationMode(recursive=False)`` are handled + identically. """ _memoization_mode = MemoizationMode(recursive=False) - @classmethod - def create(cls, *children, **props) -> Component: - """Create a new memoization leaf component. - - Args: - *children: The children of the component. - **props: The props of the component. - - Returns: - The memoization leaf - """ - comp = super().create(*children, **props) - if comp._get_all_hooks(): - comp._memoization_mode = dataclasses.replace( - comp._memoization_mode, disposition=MemoizationDisposition.ALWAYS - ) - return comp - load_dynamic_serializer() diff --git a/packages/reflex-base/src/reflex_base/components/dynamic.py b/packages/reflex-base/src/reflex_base/components/dynamic.py index 6c2100a40e8..a668bd341fc 100644 --- a/packages/reflex-base/src/reflex_base/components/dynamic.py +++ b/packages/reflex-base/src/reflex_base/components/dynamic.py @@ -26,14 +26,20 @@ def get_cdn_url(lib: str) -> str: return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm" -bundled_libraries = [ +DEFAULT_BUNDLED_LIBRARIES = [ "react", - "@radix-ui/themes", "@emotion/react", f"$/{constants.Dirs.UTILS}/context", f"$/{constants.Dirs.UTILS}/state", f"$/{constants.Dirs.UTILS}/components", ] +bundled_libraries = list(DEFAULT_BUNDLED_LIBRARIES) + + +def reset_bundled_libraries() -> None: + """Reset the bundled library registry to its default values.""" + bundled_libraries.clear() + bundled_libraries.extend(DEFAULT_BUNDLED_LIBRARIES) def bundle_library(component: Union["Component", str]): @@ -46,7 +52,7 @@ def bundle_library(component: Union["Component", str]): DynamicComponentMissingLibraryError: Raised when a dynamic component is missing a library. """ if isinstance(component, str): - bundled_libraries.append(component) + bundled_libraries.append(format_library_name(component)) return if component.library is None: msg = "Component must have a library to bundle." @@ -85,9 +91,8 @@ def make_component(component: Component) -> str: rendered_components.update(component._get_all_custom_code()) rendered_components[ - templates.stateful_component_template( + templates.dynamic_component_template( tag_name="MySSRComponent", - memo_trigger_hooks=[], component=component, export=True, ) @@ -110,7 +115,7 @@ def make_component(component: Component) -> str: else: imports[lib] = names - module_code_lines = templates.stateful_components_template( + module_code_lines = templates.dynamic_components_module_template( imports=utils.compile_imports(imports), memoized_code="\n".join(rendered_components), ).splitlines() @@ -191,6 +196,7 @@ def evaluate_component(js_string: Var[str]) -> Var[Component]: imports.ImportVar(tag="evalReactComponent"), ], "react": [ + imports.ImportVar(tag="createElement"), imports.ImportVar(tag="useState"), imports.ImportVar(tag="useEffect"), ], @@ -202,7 +208,7 @@ def evaluate_component(js_string: Var[str]) -> Var[Component]: f"evalReactComponent({js_string!s})" ".then((component) => {" "if (isMounted) {" - f"set_{unique_var_name}(component);" + f"set_{unique_var_name}(() => createElement(component));" "}" "});" "return () => {" diff --git a/packages/reflex-base/src/reflex_base/components/memoize_helpers.py b/packages/reflex-base/src/reflex_base/components/memoize_helpers.py new file mode 100644 index 00000000000..9b74e881753 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/components/memoize_helpers.py @@ -0,0 +1,295 @@ +"""Memoization helpers for auto-memoized and pseudo-stateful components. + +These helpers wrap a component's non-lifecycle event triggers in ``useCallback`` +so that React can skip re-renders of subtrees whose event handlers have stable +identities. They are used by both the compiler auto-memoization plugin (see +``reflex.compiler.plugins.memoize``) and by component-creation-time consumers +in ``reflex-components-core`` (e.g. ``WindowEventListener``, ``upload``). + +Auto-memoized components compile using one of two render strategies: + +- Passthrough memo bodies render the root component with a ``{children}`` hole. + The page still renders the descendants, which keeps root-level introspection + such as ``Form._get_form_refs`` working against the authored child tree. +- Snapshot memo bodies render the captured subtree in the memo module. This is + required for non-recursive memoization leaves and structural forms + (``Foreach``) whose stateful render logic belongs inside the memo component + rather than the containing page. +""" + +from __future__ import annotations + +import enum +from hashlib import md5 +from typing import TYPE_CHECKING + +from reflex_base.components.component import BaseComponent, Component +from reflex_base.constants import EventTriggers +from reflex_base.event import EventChain, EventSpec +from reflex_base.utils.imports import ImportVar +from reflex_base.vars import VarData +from reflex_base.vars.base import LiteralVar, Var +from reflex_base.vars.sequence import ArrayVar + +if TYPE_CHECKING: + from reflex_base.plugins.compiler import PageContext + + +class MemoizationStrategy(enum.Enum): + """How an auto-memo wrapper should render a component if it is memoized.""" + + PASSTHROUGH = "passthrough" + SNAPSHOT = "snapshot" + + +def _get_hook_deps(hook: str) -> list[str]: + """Extract Var deps from a hook declaration line. + + Args: + hook: The hook line (e.g. ``"const foo = useState(...)"``). + + Returns: + The names of variables created by the declaration. + """ + var_decl = hook.partition("=")[0].strip() + if not any(var_decl.startswith(kw) for kw in ["const ", "let ", "var "]): + return [] + _, _, var_name = var_decl.partition(" ") + var_name = var_name.strip() + if var_name.startswith(("[", "{")): + return [v.strip().replace("...", "") for v in var_name.strip("[]{}").split(",")] + return [var_name] + + +def _get_deps_from_event_trigger( + event: EventChain | EventSpec | Var, +) -> dict[str, None]: + """Get the dependencies accessed by an event trigger value. + + Args: + event: The event trigger value. + + Returns: + Dependency names, insertion-ordered. + """ + events: list = [event] + deps: dict[str, None] = {} + + if isinstance(event, EventChain): + events.extend(event.events) + + for ev in events: + if isinstance(ev, EventSpec): + for arg in ev.args: + for a in arg: + var_datas = VarData.merge(a._get_all_var_data()) + if var_datas and var_datas.deps is not None: + deps |= {str(dep): None for dep in var_datas.deps} + return deps + + +def get_memoized_event_triggers( + component: Component, +) -> dict[str, Var]: + """Generate ``useCallback`` wrappers for the component's event triggers. + + Args: + component: The component whose event triggers should be memoized. + + Returns: + A dict mapping event trigger name to memoized_triger. + """ + trigger_memo: dict[str, Var] = {} + for event_trigger, event_args in component._get_vars_from_event_triggers( + component.event_triggers + ): + if event_trigger in { + EventTriggers.ON_MOUNT, + EventTriggers.ON_UNMOUNT, + EventTriggers.ON_SUBMIT, + }: + # Do not memoize lifecycle or submit events. + continue + + event = component.event_triggers[event_trigger] + rendered_chain = LiteralVar.create(event) + + chain_hash = md5( + str(rendered_chain).encode("utf-8"), usedforsecurity=False + ).hexdigest() + memo_name = f"{event_trigger}_{chain_hash}" + + var_deps = ["addEvents", "ReflexEvent"] + var_deps.extend(_get_deps_from_event_trigger(event)) + + event_var_data = [] + for arg in event_args: + var_data = arg._get_all_var_data() + if var_data is None: + continue + event_var_data.append(var_data) + for hook in var_data.hooks: + var_deps.extend(_get_hook_deps(hook)) + + memo_var_data = VarData.merge( + *event_var_data, + rendered_chain._get_all_var_data(), + VarData( + hooks=[ + f"const {memo_name} = useCallback({rendered_chain!s}, [{', '.join(var_deps)}])" + ], + imports={"react": [ImportVar(tag="useCallback")]}, + ), + ) + + trigger_memo[event_trigger] = Var( + _js_expr=memo_name, _var_type=EventChain, _var_data=memo_var_data + ) + return trigger_memo + + +def fix_event_triggers_for_memo( + component: Component, page_context: PageContext +) -> Component: + """Return a component whose event triggers reference memoized ``useCallback``s. + + Replaces each (non-lifecycle) event-trigger value with a ``Var`` naming a + memoized ``useCallback`` wrapper. The original is never mutated — a + page-local clone is taken via ``page_context.own`` on first write. + + Args: + component: The component whose event triggers to memoize. + page_context: The active page context, used to obtain a page-local + clone before rewriting ``event_triggers``. + + Returns: + Either ``component`` (when nothing needed rewriting) or a page-local + clone with the rewritten ``event_triggers``. + """ + memo_event_triggers = tuple(get_memoized_event_triggers(component).items()) + if not memo_event_triggers: + return component + owned = page_context.own(component) + owned.event_triggers = { + **component.event_triggers, + **dict(memo_event_triggers), + } + return owned + + +def is_snapshot_boundary(component: Component) -> bool: + """Whether ``component`` owns its subtree for memoization purposes. + + Snapshot boundaries (``MemoizationLeaf``-style components with + ``_memoization_mode.recursive=False``) encapsulate internal machinery as + their own structural children. The auto-memoize compiler pass must wrap + them whole and not walk or independently memoize that subtree. + + The check is the behavioral flag, not ``isinstance(MemoizationLeaf)``, so + components that opt into non-recursive memoization without subclassing + ``MemoizationLeaf`` are handled identically. + + Args: + component: The component to classify. + + Returns: + ``True`` iff descendants of ``component`` must not be independently + memoized and the memo wrapper must carry the full subtree snapshot. + """ + return not component._memoization_mode.recursive + + +def _is_structural_memoization_child(component: Component) -> bool: + """Check whether ``component`` is a structural child for memoization. + + Args: + component: The child component to inspect. + + Returns: + True when the component's render body must stay inside the generated + memo body rather than flowing through as a normal ``children`` payload. + """ + from reflex_components_core.core.foreach import Foreach + + return bool(isinstance(component, Foreach)) + + +def _has_memoization_snapshot_child(component: Component) -> bool: + """Whether ``component`` has a structural child that needs a memo snapshot. + + Component-valued ``Foreach`` is a structural form, not an ordinary user + child. It must render its body in the memo module instead of flowing + through as a normal ``children`` payload. + + Args: + component: The component whose direct children should be inspected. + + Returns: + True when a direct child requires the parent memo wrapper to render a + captured snapshot. + """ + return any( + isinstance(child, Component) and _is_structural_memoization_child(child) + for child in component.children + ) + + +def passthrough_children_var( + children: list[BaseComponent], +) -> ArrayVar[list[BaseComponent]] | None: + """Return the placeholder ``children`` array Var if ``children`` is a memo hole. + + Auto-memo passthrough wrappers replace the wrapped component's children + with a single ``Bare(Var[Component](_js_expr="children"))`` placeholder + when compiling the memo body. Render-time consumers (notably ``Cond`` and + ``Match``) detect this and rewrite branches to index into the placeholder + array instead of capturing the original branch JSX in the memo body. The + returned Var is retyped to ``list[BaseComponent]`` so callers can index + it directly. + + Args: + children: The component's children list. + + Returns: + The placeholder Var (retyped as a Component list) for indexed access, + else ``None``. + """ + from reflex_components_core.base.bare import Bare + + if ( + len(children) == 1 + and isinstance(children[0], Bare) + and isinstance(children[0].contents, Var) + and children[0].contents._js_expr == "children" + ): + return children[0].contents.to(list[BaseComponent]) + return None + + +def get_memoization_strategy(component: Component) -> MemoizationStrategy: + """Get the render strategy for ``component`` if auto-memoization wraps it. + + Args: + component: The component being considered by auto-memoization. + + Returns: + The strategy to use when generating a memo wrapper. + """ + if ( + is_snapshot_boundary(component) + or _is_structural_memoization_child(component) + or _has_memoization_snapshot_child(component) + ): + return MemoizationStrategy.SNAPSHOT + + return MemoizationStrategy.PASSTHROUGH + + +__all__ = [ + "MemoizationStrategy", + "fix_event_triggers_for_memo", + "get_memoization_strategy", + "get_memoized_event_triggers", + "is_snapshot_boundary", + "passthrough_children_var", +] diff --git a/packages/reflex-base/src/reflex_base/environment.py b/packages/reflex-base/src/reflex_base/environment.py index 31ebe795998..f2d4ce44cb9 100644 --- a/packages/reflex-base/src/reflex_base/environment.py +++ b/packages/reflex-base/src/reflex_base/environment.py @@ -2,14 +2,11 @@ from __future__ import annotations -import concurrent.futures import dataclasses import enum import importlib -import multiprocessing import os -import platform -from collections.abc import Callable, Sequence +from collections.abc import Sequence from functools import lru_cache from pathlib import Path from typing import ( @@ -529,97 +526,6 @@ class PerformanceMode(enum.Enum): OFF = "off" -class ExecutorType(enum.Enum): - """Executor for compiling the frontend.""" - - THREAD = "thread" - PROCESS = "process" - MAIN_THREAD = "main_thread" - - @classmethod - def get_executor_from_environment(cls): - """Get the executor based on the environment variables. - - Returns: - The executor. - """ - from reflex_base.utils import console - - executor_type = environment.REFLEX_COMPILE_EXECUTOR.get() - - reflex_compile_processes = environment.REFLEX_COMPILE_PROCESSES.get() - reflex_compile_threads = environment.REFLEX_COMPILE_THREADS.get() - # By default, use the main thread. Unless the user has specified a different executor. - # Using a process pool is much faster, but not supported on all platforms. It's gated behind a flag. - if executor_type is None: - if ( - platform.system() not in ("Linux", "Darwin") - and reflex_compile_processes is not None - ): - console.warn("Multiprocessing is only supported on Linux and MacOS.") - - if ( - platform.system() in ("Linux", "Darwin") - and reflex_compile_processes is not None - ): - if reflex_compile_processes == 0: - console.warn( - "Number of processes must be greater than 0. If you want to use the default number of processes, set REFLEX_COMPILE_EXECUTOR to 'process'. Defaulting to None." - ) - reflex_compile_processes = None - elif reflex_compile_processes < 0: - console.warn( - "Number of processes must be greater than 0. Defaulting to None." - ) - reflex_compile_processes = None - executor_type = ExecutorType.PROCESS - elif reflex_compile_threads is not None: - if reflex_compile_threads == 0: - console.warn( - "Number of threads must be greater than 0. If you want to use the default number of threads, set REFLEX_COMPILE_EXECUTOR to 'thread'. Defaulting to None." - ) - reflex_compile_threads = None - elif reflex_compile_threads < 0: - console.warn( - "Number of threads must be greater than 0. Defaulting to None." - ) - reflex_compile_threads = None - executor_type = ExecutorType.THREAD - else: - executor_type = ExecutorType.MAIN_THREAD - - match executor_type: - case ExecutorType.PROCESS: - executor = concurrent.futures.ProcessPoolExecutor( - max_workers=reflex_compile_processes, - mp_context=multiprocessing.get_context("fork"), - ) - case ExecutorType.THREAD: - executor = concurrent.futures.ThreadPoolExecutor( - max_workers=reflex_compile_threads - ) - case ExecutorType.MAIN_THREAD: - FUTURE_RESULT_TYPE = TypeVar("FUTURE_RESULT_TYPE") - - class MainThreadExecutor: - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - def submit( - self, fn: Callable[..., FUTURE_RESULT_TYPE], *args, **kwargs - ) -> concurrent.futures.Future[FUTURE_RESULT_TYPE]: - future_job = concurrent.futures.Future() - future_job.set_result(fn(*args, **kwargs)) - return future_job - - executor = MainThreadExecutor() - - return executor - - class EnvironmentVariables: """Environment variables class to instantiate environment variables.""" @@ -660,14 +566,6 @@ class EnvironmentVariables: Path(constants.Dirs.UPLOADED_FILES) ) - REFLEX_COMPILE_EXECUTOR: EnvVar[ExecutorType | None] = env_var(None) - - # Whether to use separate processes to compile the frontend and how many. If not set, defaults to thread executor. - REFLEX_COMPILE_PROCESSES: EnvVar[int | None] = env_var(None) - - # Whether to use separate threads to compile the frontend and how many. Defaults to `min(32, os.cpu_count() + 4)`. - REFLEX_COMPILE_THREADS: EnvVar[int | None] = env_var(None) - # The directory to store reflex dependencies. REFLEX_DIR: EnvVar[Path] = env_var(constants.Reflex.DIR) diff --git a/packages/reflex-base/src/reflex_base/plugins/__init__.py b/packages/reflex-base/src/reflex_base/plugins/__init__.py index b4320489b08..f3ef5aa971c 100644 --- a/packages/reflex-base/src/reflex_base/plugins/__init__.py +++ b/packages/reflex-base/src/reflex_base/plugins/__init__.py @@ -3,12 +3,26 @@ from . import sitemap, tailwind_v3, tailwind_v4 from ._screenshot import ScreenshotPlugin as _ScreenshotPlugin from .base import CommonContext, Plugin, PreCompileContext +from .compiler import ( + BaseContext, + CompileContext, + CompilerHooks, + ComponentAndChildren, + PageContext, + PageDefinition, +) from .sitemap import SitemapPlugin from .tailwind_v3 import TailwindV3Plugin from .tailwind_v4 import TailwindV4Plugin __all__ = [ + "BaseContext", "CommonContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "PageContext", + "PageDefinition", "Plugin", "PreCompileContext", "SitemapPlugin", diff --git a/packages/reflex-base/src/reflex_base/plugins/base.py b/packages/reflex-base/src/reflex_base/plugins/base.py index 52dfa8d7805..082258ddb9c 100644 --- a/packages/reflex-base/src/reflex_base/plugins/base.py +++ b/packages/reflex-base/src/reflex_base/plugins/base.py @@ -1,13 +1,25 @@ """Base class for all plugins.""" from collections.abc import Callable, Sequence +from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, ParamSpec, Protocol, TypedDict +from typing import TYPE_CHECKING, Any, ClassVar, ParamSpec, Protocol, TypedDict from typing_extensions import Unpack + +class HookOrder(str, Enum): + """Dispatch bucket for a compiler ``enter_component`` / ``leave_component`` hook.""" + + PRE = "pre" + NORMAL = "normal" + POST = "post" + + if TYPE_CHECKING: from reflex.app import App, UnevaluatedPage + from reflex_base.components.component import BaseComponent + from reflex_base.plugins.compiler import ComponentAndChildren, PageContext class CommonContext(TypedDict): @@ -41,6 +53,7 @@ class PreCompileContext(CommonContext): add_save_task: AddTaskProtocol add_modify_task: Callable[[str, Callable[[str], str]], None] + radix_themes_plugin: Any unevaluated_pages: Sequence["UnevaluatedPage"] @@ -53,6 +66,13 @@ class PostCompileContext(CommonContext): class Plugin: """Base class for all plugins.""" + # Dispatch position for ``enter_component`` and ``leave_component`` hooks. + # Plugins run in ``PRE`` → ``NORMAL`` → ``POST`` order. Within a bucket, + # enter hooks fire in plugin-chain order while leave hooks fire in + # reverse plugin-chain order (mirroring an enter/leave stack). + _compiler_enter_component_order: ClassVar[HookOrder] = HookOrder.NORMAL + _compiler_leave_component_order: ClassVar[HookOrder] = HookOrder.NORMAL + def get_frontend_development_dependencies( self, **context: Unpack[CommonContext] ) -> list[str] | set[str] | tuple[str, ...]: @@ -117,6 +137,78 @@ def post_compile(self, **context: Unpack[PostCompileContext]) -> None: context: The context for the plugin. """ + def eval_page( + self, + page_fn: Any, + /, + **kwargs: Any, + ) -> "PageContext | None": + """Evaluate a page-like object into a page context. + + Args: + page_fn: The page-like object to evaluate. + kwargs: Additional compiler-specific context. + + Returns: + A page context when the plugin can evaluate the page, otherwise ``None``. + """ + return None + + def compile_page( + self, + page_ctx: "PageContext", + /, + **kwargs: Any, + ) -> None: + """Finalize a page context after its component tree has been traversed.""" + return + + def enter_component( + self, + comp: "BaseComponent", + /, + *, + page_context: "PageContext", + compile_context: Any, + in_prop_tree: bool = False, + ) -> "BaseComponent | ComponentAndChildren | None": + """Inspect or transform a component before visiting its descendants. + + Args: + comp: The component being compiled. + page_context: The active page compilation state. + compile_context: The active compile-run state. + in_prop_tree: Whether the component is being visited through a prop subtree. + + Returns: + An optional replacement component and/or structural children. + """ + return None + + def leave_component( + self, + comp: "BaseComponent", + children: tuple["BaseComponent", ...], + /, + *, + page_context: "PageContext", + compile_context: Any, + in_prop_tree: bool = False, + ) -> "BaseComponent | ComponentAndChildren | None": + """Inspect or transform a component after visiting its descendants. + + Args: + comp: The component being compiled. + children: The compiled structural children for the component. + page_context: The active page compilation state. + compile_context: The active compile-run state. + in_prop_tree: Whether the component is being visited through a prop subtree. + + Returns: + An optional replacement component and/or structural children. + """ + return None + def __repr__(self): """Return a string representation of the plugin. diff --git a/packages/reflex-base/src/reflex_base/plugins/compiler.py b/packages/reflex-base/src/reflex_base/plugins/compiler.py new file mode 100644 index 00000000000..ecb55a03d92 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/plugins/compiler.py @@ -0,0 +1,869 @@ +"""Compiler plugin infrastructure: protocols, contexts, and dispatch.""" + +from __future__ import annotations + +import copy +import dataclasses +import inspect +from collections.abc import Callable, Sequence +from contextvars import ContextVar, Token +from types import TracebackType +from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, cast + +from typing_extensions import Self + +from reflex_base.components.component import BaseComponent, Component +from reflex_base.utils.imports import ParsedImportDict, collapse_imports, merge_imports +from reflex_base.vars import VarData + +from .base import HookOrder, Plugin + +if TYPE_CHECKING: + from reflex.app import App, ComponentCallable + + PageComponent: TypeAlias = Component | ComponentCallable +else: + PageComponent: TypeAlias = ( + Component + | Callable[ + [], + Component | tuple[Component, ...] | str, + ] + ) + + +_BaseComponentT = TypeVar("_BaseComponentT", bound=BaseComponent) + + +class PageDefinition(Protocol): + """Protocol for page-like objects compiled by :class:`CompileContext`.""" + + @property + def route(self) -> str: + """Return the route for this page definition.""" + ... + + @property + def component(self) -> PageComponent: + """Return the component or callable for this page definition.""" + ... + + +ComponentAndChildren: TypeAlias = tuple[BaseComponent, tuple[BaseComponent, ...]] +ComponentReplacement: TypeAlias = BaseComponent | ComponentAndChildren | None +CompiledEnterHook: TypeAlias = Callable[ + [BaseComponent, bool], + ComponentReplacement, +] +CompiledLeaveHook: TypeAlias = Callable[ + [BaseComponent, tuple[BaseComponent, ...], bool], + ComponentReplacement, +] +EnterHookBinder: TypeAlias = Callable[ + ["PageContext", "CompileContext"], + CompiledEnterHook, +] +LeaveHookBinder: TypeAlias = Callable[ + ["PageContext", "CompileContext"], + CompiledLeaveHook, +] + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class CompilerHooks: + """Dispatch compiler hooks across an ordered plugin chain.""" + + plugins: tuple[Plugin, ...] = () + _eval_page_hooks: tuple[Callable[..., Any], ...] = dataclasses.field( + init=False, + repr=False, + ) + _compile_page_hooks: tuple[Callable[..., Any], ...] = dataclasses.field( + init=False, + repr=False, + ) + _enter_component_hook_binders: tuple[EnterHookBinder, ...] = dataclasses.field( + init=False, + repr=False, + ) + _leave_component_hook_binders: tuple[LeaveHookBinder, ...] = dataclasses.field( + init=False, + repr=False, + ) + _component_hooks_can_replace: bool = dataclasses.field( + init=False, + repr=False, + ) + + def __post_init__(self) -> None: + """Resolve the active compiler hook callables once.""" + object.__setattr__(self, "_eval_page_hooks", self._resolve_hooks("eval_page")) + object.__setattr__( + self, + "_compile_page_hooks", + self._resolve_hooks("compile_page"), + ) + enter_buckets: dict[HookOrder, list[EnterHookBinder]] = { + order: [] for order in HookOrder + } + leave_buckets: dict[HookOrder, list[LeaveHookBinder]] = { + order: [] for order in HookOrder + } + component_hooks_can_replace = False + + for plugin in self.plugins: + plugin_type = type(plugin) + if ( + hook_impl := self._get_hook_impl(plugin, "enter_component") + ) is not None: + enter_buckets[plugin_type._compiler_enter_component_order].append( + self._get_enter_hook_binder(plugin, hook_impl) + ) + component_hooks_can_replace = component_hooks_can_replace or bool( + getattr(plugin_type, "_compiler_can_replace_enter_component", True) + ) + + if ( + hook_impl := self._get_hook_impl(plugin, "leave_component") + ) is not None: + leave_buckets[plugin_type._compiler_leave_component_order].append( + self._get_leave_hook_binder(plugin, hook_impl) + ) + component_hooks_can_replace = component_hooks_can_replace or bool( + getattr(plugin_type, "_compiler_can_replace_leave_component", True) + ) + + object.__setattr__( + self, + "_enter_component_hook_binders", + tuple(binder for order in HookOrder for binder in enter_buckets[order]), + ) + object.__setattr__( + self, + "_leave_component_hook_binders", + tuple( + binder + for order in HookOrder + for binder in reversed(leave_buckets[order]) + ), + ) + object.__setattr__( + self, + "_component_hooks_can_replace", + component_hooks_can_replace, + ) + + @staticmethod + def _get_hook_impl( + plugin: Plugin, + hook_name: str, + ) -> Callable[..., Any] | None: + """Return the concrete hook implementation for a plugin, if any. + + Args: + plugin: The plugin to inspect. + hook_name: The hook attribute name. + + Returns: + The bound hook implementation, or ``None`` when the hook is inherited + unchanged from the default base implementation. + """ + plugin_impl = inspect.getattr_static(type(plugin), hook_name, None) + if plugin_impl is None: + return None + + if plugin_impl is inspect.getattr_static(Plugin, hook_name, None): + return None + + return cast(Callable[..., Any], getattr(plugin, hook_name, None)) + + def _resolve_hooks(self, hook_name: str) -> tuple[Callable[..., Any], ...]: + """Resolve concrete hook implementations for the plugin chain. + + Args: + hook_name: The hook attribute name. + + Returns: + The ordered concrete hook implementations for the hook. + """ + return tuple( + hook_impl + for plugin in self.plugins + if (hook_impl := self._get_hook_impl(plugin, hook_name)) is not None + ) + + @staticmethod + def _get_enter_hook_binder( + plugin: Plugin, + hook_impl: Callable[..., Any], + ) -> EnterHookBinder: + """Return a binder that produces a compiled enter-component hook.""" + if ( + binder := getattr(plugin, "_compiler_bind_enter_component", None) + ) is not None: + return cast(EnterHookBinder, binder) + + def bind( + page_context: PageContext, compile_context: CompileContext + ) -> CompiledEnterHook: + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> ComponentReplacement: + return cast( + ComponentReplacement, + hook_impl( + comp, + page_context=page_context, + compile_context=compile_context, + in_prop_tree=in_prop_tree, + ), + ) + + return enter_component + + return bind + + @staticmethod + def _get_leave_hook_binder( + plugin: Plugin, + hook_impl: Callable[..., Any], + ) -> LeaveHookBinder: + """Return a binder that produces a compiled leave-component hook.""" + if ( + binder := getattr(plugin, "_compiler_bind_leave_component", None) + ) is not None: + return cast(LeaveHookBinder, binder) + + def bind( + page_context: PageContext, compile_context: CompileContext + ) -> CompiledLeaveHook: + def leave_component( + comp: BaseComponent, + children: tuple[BaseComponent, ...], + in_prop_tree: bool, + ) -> ComponentReplacement: + return cast( + ComponentReplacement, + hook_impl( + comp, + children, + page_context=page_context, + compile_context=compile_context, + in_prop_tree=in_prop_tree, + ), + ) + + return leave_component + + return bind + + def eval_page( + self, + page_fn: PageComponent, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext | None: + """Return the first page context produced by the plugin chain.""" + for hook_impl in self._eval_page_hooks: + result = hook_impl(page_fn, page=page, **kwargs) + if result is not None: + return cast(PageContext, result) + return None + + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Run all ``compile_page`` hooks in plugin order.""" + for hook_impl in self._compile_page_hooks: + hook_impl(page_ctx, **kwargs) + + def compile_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree once while dispatching cached enter/leave hooks. + + Returns: + The compiled component root for this subtree. + """ + enter_hooks = tuple( + hook_binder(page_context, compile_context) + for hook_binder in self._enter_component_hook_binders + ) + + if not self._component_hooks_can_replace: + leave_hooks = tuple( + hook_binder(page_context, compile_context) + for hook_binder in self._leave_component_hook_binders + ) + + if len(enter_hooks) == 1 and not leave_hooks: + return self._compile_component_single_enter_fast_path( + comp, + enter_hook=enter_hooks[0], + page_context=page_context, + in_prop_tree=in_prop_tree, + ) + + return self._compile_component_without_replacements( + comp, + enter_hooks=enter_hooks, + leave_hooks=leave_hooks, + page_context=page_context, + in_prop_tree=in_prop_tree, + ) + + return self._compile_component_with_replacements( + comp, + enter_hooks=enter_hooks, + leave_hooks=tuple( + hook_binder(page_context, compile_context) + for hook_binder in self._leave_component_hook_binders + ), + page_context=page_context, + in_prop_tree=in_prop_tree, + ) + + def _compile_component_without_replacements( + self, + comp: BaseComponent, + /, + *, + enter_hooks: tuple[CompiledEnterHook, ...], + leave_hooks: tuple[CompiledLeaveHook, ...], + page_context: PageContext, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree when hook plans only observe state. + + Returns: + The compiled component root for this subtree. + """ + + def visit( + current_comp: BaseComponent, + current_in_prop_tree: bool, + ) -> BaseComponent: + for hook_impl in enter_hooks: + hook_impl( + current_comp, + current_in_prop_tree, + ) + + updated_children: list[BaseComponent] | None = None + children = current_comp.children + for index, child in enumerate(children): + compiled_child = visit( + child, + current_in_prop_tree, + ) + if updated_children is None: + if compiled_child is child: + continue + updated_children = list(children[:index]) + updated_children.append(compiled_child) + if updated_children is not None: + current_comp = page_context.own(current_comp) + current_comp.children = updated_children + + if isinstance(current_comp, Component): + for prop_component in current_comp._get_components_in_props(): + visit( + prop_component, + True, + ) + + if leave_hooks: + compiled_children = tuple(current_comp.children) + for hook_impl in leave_hooks: + hook_impl( + current_comp, + compiled_children, + current_in_prop_tree, + ) + + return current_comp + + return visit( + comp, + in_prop_tree, + ) + + def _compile_component_single_enter_fast_path( + self, + comp: BaseComponent, + /, + *, + enter_hook: CompiledEnterHook, + page_context: PageContext, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree for the common one-enter-hook fast path. + + Returns: + The compiled component root for this subtree. + """ + + def visit( + current_comp: BaseComponent, + current_in_prop_tree: bool, + ) -> BaseComponent: + enter_hook( + current_comp, + current_in_prop_tree, + ) + + updated_children: list[BaseComponent] | None = None + children = current_comp.children + for index, child in enumerate(children): + compiled_child = visit( + child, + current_in_prop_tree, + ) + if updated_children is None: + if compiled_child is child: + continue + updated_children = list(children[:index]) + updated_children.append(compiled_child) + if updated_children is not None: + current_comp = page_context.own(current_comp) + current_comp.children = updated_children + + if isinstance(current_comp, Component): + for prop_component in current_comp._get_components_in_props(): + visit( + prop_component, + True, + ) + + return current_comp + + return visit( + comp, + in_prop_tree, + ) + + def _compile_component_with_replacements( + self, + comp: BaseComponent, + /, + *, + enter_hooks: tuple[CompiledEnterHook, ...], + leave_hooks: tuple[CompiledLeaveHook, ...], + page_context: PageContext, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree while honoring hook replacements. + + Returns: + The compiled component root for this subtree. + """ + apply_replacement = self._apply_replacement + + def visit_children( + children: Sequence[BaseComponent], + current_in_prop_tree: bool, + ) -> tuple[BaseComponent, ...]: + if not children: + return () + + updated_children: list[BaseComponent] | None = None + for index, child in enumerate(children): + compiled_child = visit( + child, + current_in_prop_tree, + ) + if updated_children is None: + if compiled_child is child: + continue + updated_children = list(children[:index]) + updated_children.append(compiled_child) + if updated_children is None: + return children if isinstance(children, tuple) else tuple(children) + return tuple(updated_children) + + def visit( + current_comp: BaseComponent, + current_in_prop_tree: bool, + ) -> BaseComponent: + compiled_component = current_comp + structural_children: tuple[BaseComponent, ...] | None = None + + for hook_impl in enter_hooks: + compiled_component, structural_children = apply_replacement( + compiled_component, + structural_children, + hook_impl( + compiled_component, + current_in_prop_tree, + ), + ) + + if structural_children is None: + structural_children = tuple(compiled_component.children) + compiled_children = visit_children( + structural_children, + current_in_prop_tree, + ) + if isinstance(compiled_component, Component): + for prop_component in compiled_component._get_components_in_props(): + visit( + prop_component, + True, + ) + + for hook_impl in leave_hooks: + compiled_component, replacement_children = apply_replacement( + compiled_component, + compiled_children, + hook_impl( + compiled_component, + compiled_children, + current_in_prop_tree, + ), + ) + if replacement_children is not compiled_children: + assert replacement_children is not None + # Re-walking fires enter/leave again on any child objects + # carried over from the original children tuple. Observing + # collectors dedupe by dict key, so this is idempotent for + # today's plugins; stateful side effects on the page + # context would be double-applied. + compiled_children = visit_children( + replacement_children, + current_in_prop_tree, + ) + + current = compiled_component.children + if len(compiled_children) != len(current) or any( + a is not b for a, b in zip(compiled_children, current, strict=True) + ): + compiled_component = page_context.own(compiled_component) + compiled_component.children = list(compiled_children) + return compiled_component + + return visit( + comp, + in_prop_tree, + ) + + @staticmethod + def _apply_replacement( + comp: BaseComponent, + children: tuple[BaseComponent, ...] | None, + replacement: ComponentReplacement, + ) -> tuple[BaseComponent, tuple[BaseComponent, ...] | None]: + """Apply a plugin replacement to the current component state. + + Args: + comp: The current component. + children: The current structural children. + replacement: The plugin-supplied replacement. + + Returns: + The updated component and structural children pair. + """ + if replacement is None: + return comp, children + if isinstance(replacement, tuple): + return replacement + return replacement, children + + +@dataclasses.dataclass(kw_only=True) +class BaseContext: + """Context manager that exposes itself through a class-local context var.""" + + __context_var__: ClassVar[ContextVar[Self | None]] + + _attached_context_token: Token[Self | None] | None = dataclasses.field( + default=None, + init=False, + repr=False, + ) + + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: + """Initialize a dedicated context variable for each subclass.""" + super().__init_subclass__(**kwargs) + cls.__context_var__ = ContextVar(cls.__name__, default=None) + + @classmethod + def get(cls) -> Self: + """Return the active context instance for the current task. + + Returns: + The active context instance for the current task. + """ + context = cls.__context_var__.get() + if context is None: + msg = f"No active {cls.__name__} is attached to the current context." + raise RuntimeError(msg) + return context + + def __enter__(self) -> Self: + """Attach this context to the current task. + + Returns: + The attached context instance. + """ + if self._attached_context_token is not None: + msg = "Context is already attached and cannot be entered twice." + raise RuntimeError(msg) + self._attached_context_token = type(self).__context_var__.set(self) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task.""" + del exc_type, exc_val, exc_tb + if self._attached_context_token is None: + return + try: + type(self).__context_var__.reset(self._attached_context_token) + finally: + self._attached_context_token = None + + async def __aenter__(self) -> Self: + """Attach this context to the current task asynchronously. + + Returns: + The attached context instance. + """ + return self.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task asynchronously.""" + self.__exit__(exc_type, exc_val, exc_tb) + + def ensure_context_attached(self) -> None: + """Ensure this instance is the active context for the current task.""" + try: + current = type(self).get() + except RuntimeError as err: + msg = ( + f"{type(self).__name__} must be entered with 'with' or 'async with' " + "before calling this method." + ) + raise RuntimeError(msg) from err + if current is not self: + msg = f"{type(self).__name__} is not attached to the current task context." + raise RuntimeError(msg) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class PageContext(BaseContext): + """Mutable compilation state for a single page.""" + + name: str + route: str + root_component: BaseComponent + imports: list[ParsedImportDict] = dataclasses.field(default_factory=list) + module_code: dict[str, None] = dataclasses.field(default_factory=dict) + hooks: dict[str, VarData | None] = dataclasses.field(default_factory=dict) + dynamic_imports: set[str] = dataclasses.field(default_factory=set) + refs: dict[str, None] = dataclasses.field(default_factory=dict) + app_wrap_components: dict[tuple[int, str], Component] = dataclasses.field( + default_factory=dict + ) + frontend_imports: ParsedImportDict = dataclasses.field(default_factory=dict) + output_path: str | None = None + output_code: str | None = None + # Stack of ``id(component)`` for components whose subtree is + # memoize-suppressed. Populated by ``MemoizeStatefulPlugin`` when it + # encounters a ``MemoizationLeaf``-style snapshot boundary and popped on + # the matching ``leave_component``. Non-empty iff we are inside such a + # subtree. + memoize_suppressor_stack: list[int] = dataclasses.field(default_factory=list) + # Maps both the user-owned original's ``id()`` and the clone's ``id()`` to + # the page-local clone. Lets the walker and plugins rebind children, style, + # or event_triggers on a page-local copy without mutating a user-owned + # instance that may be referenced from another route. + _owned: dict[int, BaseComponent] = dataclasses.field(default_factory=dict) + # Strong references to originals keyed by ``id()`` above. Without these, + # an original that is only reachable through ``_owned``'s int key can be + # garbage collected, and Python may recycle its ``id()`` for a fresh + # component, causing ``own()`` to hand back the wrong clone. + _owned_refs: list[BaseComponent] = dataclasses.field(default_factory=list) + + def own(self, comp: _BaseComponentT) -> _BaseComponentT: + """Return a page-local copy of ``comp``, cloning on first encounter. + + Repeated calls with the same original return the same clone, so + mutations from several plugins accumulate on one instance. + + Args: + comp: The component the caller is about to mutate. + + Returns: + A component the caller may freely mutate without touching any + user-owned instance. + """ + existing = self._owned.get(id(comp)) + if existing is not None: + return cast("_BaseComponentT", existing) + new = copy.copy(comp) + self._owned[id(comp)] = new + self._owned[id(new)] = new + self._owned_refs.append(comp) + return new + + def merged_imports(self, *, collapse: bool = False) -> ParsedImportDict: + """Return the imports accumulated for this page. + + Args: + collapse: Whether to collapse duplicate imports. + + Returns: + The merged page imports. + """ + imports = merge_imports(*self.imports) if self.imports else {} + return collapse_imports(imports) if collapse else imports + + def custom_code_dict(self) -> dict[str, None]: + """Return custom-code snippets keyed like legacy collectors. + + Returns: + The page custom code keyed by snippet. + """ + return dict(self.module_code) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class CompileContext(BaseContext): + """Mutable compilation state for an entire compile run.""" + + app: App | None = None + pages: Sequence[PageDefinition] + hooks: CompilerHooks = dataclasses.field(default_factory=CompilerHooks) + compiled_pages: dict[str, PageContext] = dataclasses.field(default_factory=dict) + all_imports: ParsedImportDict = dataclasses.field(default_factory=dict) + app_wrap_components: dict[tuple[int, str], Component] = dataclasses.field( + default_factory=dict + ) + stateful_routes: dict[str, None] = dataclasses.field(default_factory=dict) + # Auto-memoize wrapper tags seen during the tree walk (populated by + # ``MemoizeStatefulPlugin``). + memoize_wrappers: dict[str, None] = dataclasses.field(default_factory=dict) + # Compiler-generated experimental memo definitions for auto-memoized + # stateful wrappers. Stored as ``Any`` to keep ``reflex_base`` decoupled + # from ``reflex.experimental.memo``. + auto_memo_components: dict[str, Any] = dataclasses.field(default_factory=dict) + + def compile( + self, + *, + evaluate_progress: Callable[[], None] | None = None, + render_progress: Callable[[], None] | None = None, + **kwargs: Any, + ) -> dict[str, PageContext]: + """Compile all configured pages through the plugin pipeline. + + Args: + evaluate_progress: Callback invoked after each page evaluation. + render_progress: Callback invoked after each page render. + kwargs: Additional compiler-specific context. + + Returns: + The compiled page contexts keyed by route. + """ + from reflex.compiler import compiler + from reflex.state import all_base_state_classes + + self.ensure_context_attached() + self.compiled_pages.clear() + self.all_imports.clear() + self.app_wrap_components.clear() + self.stateful_routes.clear() + self.memoize_wrappers.clear() + self.auto_memo_components.clear() + + for page in self.pages: + page_fn = page.component + n_states_before = len(all_base_state_classes) + page_ctx = self.hooks.eval_page( + page_fn, + page=page, + compile_context=self, + **kwargs, + ) + if page_ctx is None: + page_name = getattr(page_fn, "__name__", repr(page_fn)) + msg = ( + f"No compiler plugin was able to evaluate page {page.route!r} " + f"({page_name})." + ) + raise RuntimeError(msg) + if page_ctx.route in self.compiled_pages: + msg = f"Duplicate compiled page route {page_ctx.route!r}." + raise RuntimeError(msg) + + if len(all_base_state_classes) > n_states_before: + self.stateful_routes[page.route] = None + + self.compiled_pages[page_ctx.route] = page_ctx + + if evaluate_progress is not None: + evaluate_progress() + + for page, page_ctx in zip( + self.pages, + self.compiled_pages.values(), + strict=True, + ): + with page_ctx: + page_ctx.root_component = self.hooks.compile_component( + page_ctx.root_component, + page_context=page_ctx, + compile_context=self, + ) + self.hooks.compile_page( + page_ctx, + page=page, + compile_context=self, + **kwargs, + ) + + page_ctx.frontend_imports = page_ctx.merged_imports(collapse=True) + self.all_imports = merge_imports( + self.all_imports, page_ctx.frontend_imports + ) + self.app_wrap_components.update(page_ctx.app_wrap_components) + page_ctx.output_path, page_ctx.output_code = ( + compiler.compile_page_from_context(page_ctx) + ) + + if render_progress is not None: + render_progress() + + return self.compiled_pages + + +__all__ = [ + "BaseContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "PageContext", + "PageDefinition", +] diff --git a/packages/reflex-base/src/reflex_base/plugins/tailwind_v3.py b/packages/reflex-base/src/reflex_base/plugins/tailwind_v3.py index d67264fc0e3..66f575db5c2 100644 --- a/packages/reflex-base/src/reflex_base/plugins/tailwind_v3.py +++ b/packages/reflex-base/src/reflex_base/plugins/tailwind_v3.py @@ -29,7 +29,7 @@ class Constants(SimpleNamespace): ROOT_STYLE_CONTENT = """ @import "tailwindcss/base"; -@import url('{radix_url}'); +{radix_import} @tailwind components; @tailwind utilities; @@ -54,9 +54,12 @@ def compile_config(config: TailwindConfig): ) -def compile_root_style(): +def compile_root_style(include_radix_themes: bool = True): """Compile the Tailwind root style. + Args: + include_radix_themes: Whether to include the Radix stylesheet import. + Returns: The compiled Tailwind root style. """ @@ -65,7 +68,9 @@ def compile_root_style(): return str( Path(Dirs.STYLES) / Constants.ROOT_STYLE_PATH ), Constants.ROOT_STYLE_CONTENT.format( - radix_url=RADIX_THEMES_STYLESHEET, + radix_import=( + f"@import url('{RADIX_THEMES_STYLESHEET}');" if include_radix_themes else "" + ), ) @@ -112,11 +117,14 @@ def add_tailwind_to_postcss_config(postcss_file_content: str) -> str: return "\n".join(postcss_file_lines) -def add_tailwind_to_css_file(css_file_content: str) -> str: +def add_tailwind_to_css_file( + css_file_content: str, *, include_radix_themes: bool = True +) -> str: """Add tailwind to the css file. Args: css_file_content: The content of the css file. + include_radix_themes: Whether the root stylesheet already imports Radix. Returns: The modified css file content. @@ -125,16 +133,23 @@ def add_tailwind_to_css_file(css_file_content: str) -> str: if Constants.TAILWIND_CSS.splitlines()[0] in css_file_content: return css_file_content - if RADIX_THEMES_STYLESHEET not in css_file_content: - print( # noqa: T201 - f"Could not find line with '{RADIX_THEMES_STYLESHEET}' in {Dirs.STYLES}. " - "Please make sure the file exists and is valid." + if include_radix_themes and RADIX_THEMES_STYLESHEET in css_file_content: + return css_file_content.replace( + f"@import url('{RADIX_THEMES_STYLESHEET}');", + Constants.TAILWIND_CSS, ) - return css_file_content - return css_file_content.replace( - f"@import url('{RADIX_THEMES_STYLESHEET}');", - Constants.TAILWIND_CSS, + + lines = css_file_content.splitlines() + insert_at = next( + ( + index + 1 + for index, line in enumerate(lines) + if "__reflex_style_reset.css" in line + ), + 1, ) + lines.insert(insert_at, Constants.TAILWIND_CSS) + return "\n".join(lines) @dataclasses.dataclass @@ -162,9 +177,14 @@ def pre_compile(self, **context): context: The context for the plugin. """ context["add_save_task"](compile_config, self.get_unversioned_config()) - context["add_save_task"](compile_root_style) + include_radix_themes = context["radix_themes_plugin"].enabled + + context["add_save_task"](compile_root_style, include_radix_themes) context["add_modify_task"](Dirs.POSTCSS_JS, add_tailwind_to_postcss_config) context["add_modify_task"]( str(Path(Dirs.STYLES) / (PageNames.STYLESHEET_ROOT + Ext.CSS)), - add_tailwind_to_css_file, + lambda content: add_tailwind_to_css_file( + content, + include_radix_themes=include_radix_themes, + ), ) diff --git a/packages/reflex-base/src/reflex_base/plugins/tailwind_v4.py b/packages/reflex-base/src/reflex_base/plugins/tailwind_v4.py index a3f1c24eb5b..7c11dc84c2a 100644 --- a/packages/reflex-base/src/reflex_base/plugins/tailwind_v4.py +++ b/packages/reflex-base/src/reflex_base/plugins/tailwind_v4.py @@ -29,8 +29,7 @@ class Constants(SimpleNamespace): ROOT_STYLE_CONTENT = """@layer theme, base, components, utilities; @import "tailwindcss/theme.css" layer(theme); @import "tailwindcss/preflight.css" layer(base); -@import "{radix_url}" layer(components); -@import "tailwindcss/utilities.css" layer(utilities); +{radix_import}@import "tailwindcss/utilities.css" layer(utilities); @config "../tailwind.config.js"; """ @@ -53,9 +52,12 @@ def compile_config(config: TailwindConfig): ) -def compile_root_style(): +def compile_root_style(include_radix_themes: bool = True): """Compile the Tailwind root style. + Args: + include_radix_themes: Whether to include the Radix stylesheet import. + Returns: The compiled Tailwind root style. """ @@ -64,7 +66,11 @@ def compile_root_style(): return str( Path(Dirs.STYLES) / Constants.ROOT_STYLE_PATH ), Constants.ROOT_STYLE_CONTENT.format( - radix_url=RADIX_THEMES_STYLESHEET, + radix_import=( + f'@import "{RADIX_THEMES_STYLESHEET}" layer(components);\n' + if include_radix_themes + else "" + ), ) @@ -115,11 +121,14 @@ def add_tailwind_to_postcss_config(postcss_file_content: str) -> str: return "\n".join(postcss_file_lines) -def add_tailwind_to_css_file(css_file_content: str) -> str: +def add_tailwind_to_css_file( + css_file_content: str, *, include_radix_themes: bool = True +) -> str: """Add tailwind to the css file. Args: css_file_content: The content of the css file. + include_radix_themes: Whether the root stylesheet already imports Radix. Returns: The modified css file content. @@ -128,16 +137,23 @@ def add_tailwind_to_css_file(css_file_content: str) -> str: if Constants.TAILWIND_CSS.splitlines()[0] in css_file_content: return css_file_content - if RADIX_THEMES_STYLESHEET not in css_file_content: - print( # noqa: T201 - f"Could not find line with '{RADIX_THEMES_STYLESHEET}' in {Dirs.STYLES}. " - "Please make sure the file exists and is valid." + if include_radix_themes and RADIX_THEMES_STYLESHEET in css_file_content: + return css_file_content.replace( + f"@import url('{RADIX_THEMES_STYLESHEET}');", + Constants.TAILWIND_CSS, ) - return css_file_content - return css_file_content.replace( - f"@import url('{RADIX_THEMES_STYLESHEET}');", - Constants.TAILWIND_CSS, + + lines = css_file_content.splitlines() + insert_at = next( + ( + index + 1 + for index, line in enumerate(lines) + if "__reflex_style_reset.css" in line + ), + 1, ) + lines.insert(insert_at, Constants.TAILWIND_CSS) + return "\n".join(lines) @dataclasses.dataclass @@ -166,9 +182,14 @@ def pre_compile(self, **context): context: The context for the plugin. """ context["add_save_task"](compile_config, self.get_unversioned_config()) - context["add_save_task"](compile_root_style) + include_radix_themes = context["radix_themes_plugin"].enabled + + context["add_save_task"](compile_root_style, include_radix_themes) context["add_modify_task"](Dirs.POSTCSS_JS, add_tailwind_to_postcss_config) context["add_modify_task"]( str(Path(Dirs.STYLES) / (PageNames.STYLESHEET_ROOT + Ext.CSS)), - add_tailwind_to_css_file, + lambda content: add_tailwind_to_css_file( + content, + include_radix_themes=include_radix_themes, + ), ) diff --git a/packages/reflex-base/src/reflex_base/registry.py b/packages/reflex-base/src/reflex_base/registry.py index 8caa1d2b2c3..71b4d723e5e 100644 --- a/packages/reflex-base/src/reflex_base/registry.py +++ b/packages/reflex-base/src/reflex_base/registry.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: from reflex.state import BaseState - from reflex_base.components.component import StatefulComponent from reflex_base.event import EventHandler @@ -40,10 +39,6 @@ class RegistrationContext(BaseContext): default_factory=dict, repr=False, ) - tag_to_stateful_component: dict[str, StatefulComponent] = dataclasses.field( - default_factory=dict, - repr=False, - ) @classmethod def ensure_context(cls) -> Self: diff --git a/packages/reflex-base/src/reflex_base/utils/console.py b/packages/reflex-base/src/reflex_base/utils/console.py index b7a9539f77b..962008e42e0 100644 --- a/packages/reflex-base/src/reflex_base/utils/console.py +++ b/packages/reflex-base/src/reflex_base/utils/console.py @@ -479,6 +479,18 @@ def advance(self, task: TaskID, advance: int = 1): self.progress += advance _console.print(f"Progress: {self.progress}/{self.total}") + def update(self, task: TaskID, total: int | None = None): + """Update properties of a task. + + Args: + task: The task ID. + total: New total for the task. + """ + if total is not None and task in self.tasks: + previous_total = self.tasks[task]["total"] + self.tasks[task]["total"] = total + self.total += total - previous_total + def start(self): """Start the progress bar.""" diff --git a/packages/reflex-base/src/reflex_base/utils/streaming_response.py b/packages/reflex-base/src/reflex_base/utils/streaming_response.py index d9907379cef..66c6ab7fc62 100644 --- a/packages/reflex-base/src/reflex_base/utils/streaming_response.py +++ b/packages/reflex-base/src/reflex_base/utils/streaming_response.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import builtins import contextlib import sys @@ -60,14 +59,16 @@ def _collapse_excgroups() -> Generator[None, None, None]: class DisconnectAwareStreamingResponse(StreamingResponse): - """Streaming response that cancels its body task on disconnect.""" + """Streaming response with a guaranteed finish callback.""" _on_finish: Callable[[], Awaitable[None]] + _on_disconnect: Callable[[], None] | None def __init__( self, *args: Any, on_finish: Callable[[], Awaitable[None]], + on_disconnect: Callable[[], None] | None = None, **kwargs: Any, ) -> None: """Initialize the response. @@ -75,17 +76,17 @@ def __init__( Args: args: Positional args forwarded to ``StreamingResponse``. on_finish: Cleanup callback to run exactly once when the response ends. + on_disconnect: Sync callback invoked when the client disconnects. kwargs: Keyword args forwarded to ``StreamingResponse``. """ super().__init__(*args, **kwargs) self._on_finish = on_finish + self._on_disconnect = on_disconnect - async def _watch_disconnect(self, receive: Receive) -> None: - """Wait for the client connection to close.""" - while True: - message = await receive() - if message["type"] == "http.disconnect": - return + def _notify_disconnect(self) -> None: + """Invoke the on_disconnect callback if one was provided.""" + if self._on_disconnect is not None: + self._on_disconnect() async def _close_body_iterator(self) -> None: """Close the body iterator if it supports ``aclose``.""" @@ -94,7 +95,7 @@ async def _close_body_iterator(self) -> None: await aclose() async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Serve the response and cancel the body task on disconnect.""" + """Serve the response and always run the finish callback.""" spec_version = _parse_asgi_spec_version(scope) try: @@ -107,47 +108,24 @@ async def wrap(func: Callable[[], Awaitable[None]]) -> None: task_group.cancel_scope.cancel() task_group.start_soon(wrap, partial(self.stream_response, send)) - await wrap(partial(self.listen_for_disconnect, receive)) - else: - # Verified against Starlette 0.52.1: the ASGI >= 2.4 path in - # StreamingResponse.__call__ delegates straight to - # stream_response(send) and does not read from receive(). - # Keep calling stream_response(send) directly here so the - # disconnect watcher remains the only receive() consumer; if - # Starlette changes that contract, re-check this logic. - stream_task = asyncio.create_task(self.stream_response(send)) - disconnect_task = asyncio.create_task(self._watch_disconnect(receive)) - should_close_body_iterator = False + if self._on_disconnect is not None: + + async def _disconnect_then_notify() -> None: + await self.listen_for_disconnect(receive) + self._notify_disconnect() + + await wrap(_disconnect_then_notify) + else: + await wrap(partial(self.listen_for_disconnect, receive)) + else: try: - done, _ = await asyncio.wait( - {stream_task, disconnect_task}, - return_when=asyncio.FIRST_COMPLETED, - ) - if disconnect_task in done and not stream_task.done(): - should_close_body_iterator = True - stream_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await stream_task - else: - try: - await stream_task - except OSError as err: - should_close_body_iterator = True - raise ClientDisconnect from err - finally: - if not disconnect_task.done(): - disconnect_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await disconnect_task - if not stream_task.done(): - should_close_body_iterator = True - stream_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await stream_task - if should_close_body_iterator: - await self._close_body_iterator() + await self.stream_response(send) + except OSError as err: + self._notify_disconnect() + raise ClientDisconnect from err finally: + await self._close_body_iterator() await self._on_finish() if self.background is not None: diff --git a/packages/reflex-components-core/src/reflex_components_core/base/link.py b/packages/reflex-components-core/src/reflex_components_core/base/link.py index a3a7b89dfbf..78bd301bcb2 100644 --- a/packages/reflex-components-core/src/reflex_components_core/base/link.py +++ b/packages/reflex-components-core/src/reflex_components_core/base/link.py @@ -3,10 +3,10 @@ from reflex_base.components.component import field from reflex_base.vars.base import Var -from reflex_components_core.el.elements.base import BaseHTML +from reflex_components_core.el.elements.base import RawTextBaseHTML, VoidBaseHTML -class RawLink(BaseHTML): +class RawLink(VoidBaseHTML): """A component that displays the title of the current page.""" tag = "link" @@ -16,7 +16,7 @@ class RawLink(BaseHTML): rel: Var[str] = field(doc="The type of link.") -class ScriptTag(BaseHTML): +class ScriptTag(RawTextBaseHTML): """A script tag with the specified type and source.""" tag = "script" diff --git a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index 0674e071fbd..680fd7c613f 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -495,12 +495,22 @@ def _create_upload_event() -> Event: msg = "Upload event was not created." raise RuntimeError(msg) + disconnect_seen = False + + def _mark_disconnected() -> None: + nonlocal disconnect_seen + disconnect_seen = True + async def _ndjson_updates(): """Process the upload event, generating ndjson updates. Yields: Each state update as newline-delimited JSON. """ + # Let the disconnect watcher run before we enqueue the upload handler. + await asyncio.sleep(0) + if disconnect_seen: + return # Enqueue the task on the main event loop, but emit deltas to the local queue. async for delta in app.event_processor.enqueue_stream_delta(token, event): yield json_dumps(StateUpdate(delta=delta)) + "\n" @@ -509,6 +519,7 @@ async def _ndjson_updates(): _ndjson_updates(), media_type="application/x-ndjson", on_finish=_close_form_data, + on_disconnect=_mark_disconnected, ) diff --git a/packages/reflex-components-core/src/reflex_components_core/core/cond.py b/packages/reflex-components-core/src/reflex_components_core/core/cond.py index 98ba56d170d..a35443cb17c 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/cond.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/cond.py @@ -5,6 +5,7 @@ from typing import Any, TypeVar, overload from reflex_base.components.component import BaseComponent, Component, field +from reflex_base.components.memoize_helpers import passthrough_children_var from reflex_base.components.tags import CondTag, Tag from reflex_base.constants import Dirs from reflex_base.style import LIGHT_COLOR_MODE, resolved_color_mode @@ -14,6 +15,7 @@ from reflex_base.vars.base import LiteralVar, Var from reflex_base.vars.number import ternary_operation +from reflex_components_core.base.bare import Bare from reflex_components_core.base.fragment import Fragment _IS_TRUE_IMPORT: ImportDict = { @@ -26,6 +28,28 @@ class Cond(Component): cond: Var[Any] = field(doc="The cond to determine which component to render.") + def _get_cond_children(self) -> tuple[BaseComponent, BaseComponent]: + """Get true and false branch components with safe defaults. + + When rendering inside an auto-memo passthrough body, ``self.children`` + is collapsed to a single ``Bare`` holding the placeholder ``children`` + Var. The branches are reconstructed as indexed accesses + (``children[0]`` / ``children[1]``) so the page-side renders the real + branch JSX and the memo body just selects which one to mount. + + Returns: + A tuple containing true and false branch components. + """ + children_var = passthrough_children_var(self.children) + if children_var is not None: + return ( + Bare.create(children_var[0]), + Bare.create(children_var[1]), + ) + true_child = self.children[0] if self.children else Fragment.create() + false_child = self.children[1] if len(self.children) > 1 else Fragment.create() + return true_child, false_child + @classmethod def create( cls, @@ -60,10 +84,11 @@ def create( ) def _render(self) -> Tag: + true_child, false_child = self._get_cond_children() return CondTag( cond_state=str(self.cond), - true_value=self.children[0].render(), - false_value=self.children[1].render(), + true_value=true_child.render(), + false_value=false_child.render(), ) def render(self) -> dict: @@ -72,10 +97,11 @@ def render(self) -> dict: Returns: The dictionary for template of component. """ + true_child, false_child = self._get_cond_children() return { "cond_state": str(self.cond), - "true_value": self.children[0].render(), - "false_value": self.children[1].render(), + "true_value": true_child.render(), + "false_value": false_child.render(), } def add_imports(self) -> ImportDict: diff --git a/packages/reflex-components-core/src/reflex_components_core/core/debounce.py b/packages/reflex-components-core/src/reflex_components_core/core/debounce.py index 500cd1f96a6..b627e87eea5 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/debounce.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/debounce.py @@ -7,6 +7,7 @@ from reflex_base.components.component import Component, field from reflex_base.constants import EventTriggers from reflex_base.event import EventHandler, no_args_event_spec +from reflex_base.utils import format from reflex_base.vars import VarData from reflex_base.vars.base import Var @@ -106,7 +107,15 @@ def create(cls, *children: Component, **props: Any) -> Component: } props.setdefault("custom_attrs", {}).update(other_props, **child.custom_attrs) - # Carry base Component props. + # Carry base Component props. Drop any keys from child.style that + # collide with DebounceInput's own props — Reflex routes unknown child + # kwargs (e.g. ``debounce_timeout`` passed through ``rx.input``) into + # ``style``. + debounce_input_prop_names = { + format.to_camel_case(prop) for prop in cls.get_props() + } + for colliding_key in [k for k in child.style if k in debounce_input_prop_names]: + child.style.pop(colliding_key) props.setdefault("style", {}).update(child.style) if child.class_name is not None: props["class_name"] = f"{props.get('class_name', '')} {child.class_name}" diff --git a/packages/reflex-components-core/src/reflex_components_core/core/match.py b/packages/reflex-components-core/src/reflex_components_core/core/match.py index d062587b04c..9216cfb767b 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/match.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/match.py @@ -3,12 +3,8 @@ import textwrap from typing import Any, cast -from reflex_base.components.component import ( - BaseComponent, - Component, - MemoizationLeaf, - field, -) +from reflex_base.components.component import BaseComponent, Component, field +from reflex_base.components.memoize_helpers import passthrough_children_var from reflex_base.components.tags import Tag from reflex_base.components.tags.match_tag import MatchTag from reflex_base.style import Style @@ -19,9 +15,10 @@ from reflex_base.vars.base import LiteralVar, Var from reflex_components_core.base import Fragment +from reflex_components_core.base.bare import Bare -class Match(MemoizationLeaf): +class Match(Component): """Match cases based on a condition.""" cond: Var[Any] = field(doc="The condition to determine which case to match.") @@ -270,13 +267,39 @@ def _create_match_cond_var_or_component( ) def _render(self) -> Tag: + # Reconstruct match_cases and default from self.children, which may have + # been updated by the compiler walker to include memoized wrappers. + # self.children contains: [case_1_return, case_2_return, ..., default] + # self.match_cases contains the conditions as Vars. + num_cases = len(self.match_cases) + children_var = passthrough_children_var(self.children) + if children_var is not None: + # Auto-memo passthrough body: index into the placeholder array so + # branch JSX stays on the page side. + cases_returns = [Bare.create(children_var[i]) for i in range(num_cases)] + default_return = Bare.create(children_var[num_cases]) + else: + if len(self.children) != num_cases + 1: + msg = ( + f"Match children count mismatch: expected {num_cases + 1} " + f"(cases + default), got {len(self.children)}" + ) + raise ValueError(msg) + + cases_returns = self.children[:num_cases] + default_return = self.children[num_cases] + return MatchTag( cond=str(self.cond), match_cases=[ ([str(cond) for cond in conditions], return_value.render()) - for conditions, return_value in self.match_cases + for (conditions, _), return_value in zip( + self.match_cases, + cases_returns, + strict=True, + ) ], - default=self.default.render(), + default=default_return.render(), ) def render(self) -> dict: diff --git a/packages/reflex-components-core/src/reflex_components_core/core/upload.py b/packages/reflex-components-core/src/reflex_components_core/core/upload.py index 642ef851749..9bbb28de64d 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/upload.py @@ -10,9 +10,9 @@ Component, ComponentNamespace, MemoizationLeaf, - StatefulComponent, field, ) +from reflex_base.components.memoize_helpers import get_memoized_event_triggers from reflex_base.constants import Dirs from reflex_base.constants.compiler import Hooks, Imports from reflex_base.environment import environment @@ -358,16 +358,13 @@ def create(cls, *children, **props) -> Component: ), ) - event_triggers = StatefulComponent._get_memoized_event_triggers( + event_triggers = get_memoized_event_triggers( GhostUpload.create( on_drop=upload_props["on_drop"], on_drop_rejected=upload_props["on_drop_rejected"], ) ) - callback_hooks = [] - for trigger_name, (event_var, callback_str) in event_triggers.items(): - upload_props[trigger_name] = event_var - callback_hooks.append(callback_str) + upload_props.update(event_triggers) upload_props = { format.to_camel_case(key): value for key, value in upload_props.items() @@ -392,7 +389,6 @@ def create(cls, *children, **props) -> Component: use_dropzone_arguments._get_all_var_data(), VarData( hooks={ - **dict.fromkeys(callback_hooks, None), f"{left_side} = {right_side};": None, }, imports={ diff --git a/packages/reflex-components-core/src/reflex_components_core/core/window_events.py b/packages/reflex-components-core/src/reflex_components_core/core/window_events.py index debb4c3dc37..ee57bb770c0 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/window_events.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/window_events.py @@ -4,7 +4,7 @@ from typing import Any, cast -from reflex_base.components.component import StatefulComponent, field +from reflex_base.components.component import field from reflex_base.constants.compiler import Hooks from reflex_base.event import EventHandler, key_event, no_args_event_spec from reflex_base.vars.base import Var, VarData @@ -95,9 +95,15 @@ def create(cls, **props) -> WindowEventListener: Returns: The created component. """ + from reflex_base.components.memoize_helpers import get_memoized_event_triggers + real_component = cast("WindowEventListener", super().create(**props)) - hooks = StatefulComponent._fix_event_triggers(real_component) - real_component.hooks = hooks + memo_event_triggers = get_memoized_event_triggers(real_component) + if memo_event_triggers: + real_component.event_triggers = { + **real_component.event_triggers, + **memo_event_triggers, + } return real_component def _exclude_props(self) -> list[str]: diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/base.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/base.py index e8b922f7430..476e3768edb 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/base.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/base.py @@ -3,6 +3,7 @@ from typing import Literal from reflex_base.components.component import field +from reflex_base.constants.compiler import MemoizationMode from reflex_base.vars.base import Var from reflex_components_core.el.element import Element @@ -140,3 +141,29 @@ class BaseHTML(Element): ) title: Var[str] = field(doc="Defines a tooltip for the element.") + + +class RawTextBaseHTML(BaseHTML): + """Base class for HTML elements with raw-text (RCDATA) content models. + + Raw-text elements (````, ``<style>``, ``<textarea>``, ``<script>``, + ``<noscript>``) parse their content as text rather than child markup. A + JSX component child stringifies as ``[object Object]``, so a stateful + ``Bare`` child must be captured inside the element's snapshot body + rather than being independently memoized as a sibling JSX call. + """ + + _memoization_mode = MemoizationMode(recursive=False) + + +class VoidBaseHTML(BaseHTML): + """Base class for void HTML elements (no children allowed). + + Void elements (``<area>``, ``<base>``, ``<br>``, ``<col>``, ``<embed>``, + ``<hr>``, ``<img>``, ``<input>``, ``<link>``, ``<meta>``, ``<source>``, + ``<track>``, ``<wbr>``) cannot have children. A stateful ``Bare`` child + must stay inside the element's snapshot body rather than being + independently memoized into an invalid JSX call. + """ + + _memoization_mode = MemoizationMode(recursive=False) diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/forms.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/forms.py index 8a2422081a5..aef4a4ebf14 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/forms.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/forms.py @@ -29,7 +29,7 @@ from reflex_components_core.el.element import Element -from .base import BaseHTML +from .base import BaseHTML, RawTextBaseHTML, VoidBaseHTML def _handle_submit_js_template( @@ -305,7 +305,7 @@ def _exclude_props(self) -> list[str]: ] -class BaseInput(BaseHTML): +class BaseInput(VoidBaseHTML): """A base class for input elements.""" tag = "input" @@ -652,7 +652,7 @@ class Select(BaseHTML): """ -class Textarea(BaseHTML): +class Textarea(RawTextBaseHTML): """Display the textarea element.""" tag = "textarea" diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/inline.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/inline.py index 6a3bf6a3d52..486f87dd609 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/inline.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/inline.py @@ -5,7 +5,7 @@ from reflex_base.components.component import field from reflex_base.vars.base import Var -from .base import BaseHTML +from .base import BaseHTML, VoidBaseHTML ReferrerPolicy = Literal[ "", @@ -80,7 +80,7 @@ class Bdo(BaseHTML): tag = "bdo" -class Br(BaseHTML): +class Br(VoidBaseHTML): """Display the br element.""" tag = "br" @@ -220,7 +220,7 @@ class U(BaseHTML): tag = "u" -class Wbr(BaseHTML): +class Wbr(VoidBaseHTML): """Display the wbr element.""" tag = "wbr" diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/media.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/media.py index 940effdf52b..66b7981ef92 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/media.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/media.py @@ -8,10 +8,10 @@ from reflex_components_core.el.elements.inline import ReferrerPolicy -from .base import BaseHTML +from .base import BaseHTML, RawTextBaseHTML, VoidBaseHTML -class Area(BaseHTML): +class Area(VoidBaseHTML): """Display the area element.""" tag = "area" @@ -78,7 +78,7 @@ class Audio(BaseHTML): ImageLoading = Literal["eager", "lazy"] -class Img(BaseHTML): +class Img(VoidBaseHTML): """Display the img element.""" tag = "img" @@ -135,7 +135,7 @@ class Map(BaseHTML): ) -class Track(BaseHTML): +class Track(VoidBaseHTML): """Display the track element.""" tag = "track" @@ -187,7 +187,7 @@ class Video(BaseHTML): src: Var[str] = field(doc="URL of the video to play") -class Embed(BaseHTML): +class Embed(VoidBaseHTML): """Display the embed element.""" tag = "embed" @@ -251,7 +251,7 @@ class Portal(BaseHTML): tag = "portal" -class Source(BaseHTML): +class Source(VoidBaseHTML): """Display the source element.""" tag = "source" @@ -870,13 +870,13 @@ class MPath(BaseHTML): href: Var[str] = field(doc="Reference to a path element.") -class Desc(BaseHTML): +class Desc(RawTextBaseHTML): """The SVG desc component for descriptions.""" tag = "desc" -class Title(BaseHTML): +class Title(RawTextBaseHTML): """The SVG title component for titles.""" tag = "title" @@ -888,7 +888,7 @@ class Metadata(BaseHTML): tag = "metadata" -class Script(BaseHTML): +class Script(RawTextBaseHTML): """The SVG script component for scripts.""" tag = "script" @@ -900,7 +900,7 @@ class Script(BaseHTML): crossorigin: Var[str] = field(doc="CORS settings for the script.") -class SvgStyle(BaseHTML): +class SvgStyle(RawTextBaseHTML): """The SVG style component for stylesheets.""" tag = "style" diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/metadata.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/metadata.py index a05a9c0d2bc..f6211538c8f 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/metadata.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/metadata.py @@ -1,16 +1,16 @@ """Metadata classes.""" -from reflex_base.components.component import field +from reflex_base.components.component import MemoizationLeaf, field from reflex_base.vars.base import Var from reflex_components_core.el.element import Element from reflex_components_core.el.elements.inline import ReferrerPolicy from reflex_components_core.el.elements.media import CrossOrigin -from .base import BaseHTML +from .base import BaseHTML, VoidBaseHTML -class Base(BaseHTML): +class Base(VoidBaseHTML): """Display the base element.""" tag = "base" @@ -25,7 +25,7 @@ class Head(BaseHTML): tag = "head" -class Link(BaseHTML): +class Link(VoidBaseHTML): """Display the link element.""" tag = "link" @@ -61,7 +61,7 @@ class Link(BaseHTML): type: Var[str] = field(doc="Specifies the MIME type of the linked document") -class Meta(BaseHTML): # Inherits common attributes from BaseHTML +class Meta(VoidBaseHTML): # Inherits common attributes from BaseHTML """Display the meta element.""" tag = "meta" # The HTML tag for this element is <meta> @@ -81,14 +81,14 @@ class Meta(BaseHTML): # Inherits common attributes from BaseHTML property: Var[str] = field(doc="The type of metadata value.") -class Title(Element): +class Title(MemoizationLeaf, Element): """Display the title element.""" tag = "title" # Had to be named with an underscore so it doesn't conflict with reflex.style Style in pyi -class StyleEl(Element): +class StyleEl(MemoizationLeaf, Element): """Display the style element.""" tag = "style" diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/scripts.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/scripts.py index caa68031ec8..8bc4341b6cb 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/scripts.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/scripts.py @@ -6,7 +6,7 @@ from reflex_components_core.el.elements.inline import ReferrerPolicy from reflex_components_core.el.elements.media import CrossOrigin -from .base import BaseHTML +from .base import BaseHTML, RawTextBaseHTML class Canvas(BaseHTML): @@ -15,13 +15,13 @@ class Canvas(BaseHTML): tag = "canvas" -class Noscript(BaseHTML): +class Noscript(RawTextBaseHTML): """Display the noscript element.""" tag = "noscript" -class Script(BaseHTML): +class Script(RawTextBaseHTML): """Display the script element.""" tag = "script" diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/tables.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/tables.py index 37e686751cc..3aab7bd0d2d 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/tables.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/tables.py @@ -5,7 +5,7 @@ from reflex_base.components.component import field from reflex_base.vars.base import Var -from .base import BaseHTML +from .base import BaseHTML, VoidBaseHTML class Caption(BaseHTML): @@ -14,7 +14,7 @@ class Caption(BaseHTML): tag = "caption" -class Col(BaseHTML): +class Col(VoidBaseHTML): """Display the col element.""" tag = "col" diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/typography.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/typography.py index 7ac0b737235..b891201495e 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/typography.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/typography.py @@ -5,7 +5,7 @@ from reflex_base.components.component import field from reflex_base.vars.base import Var -from .base import BaseHTML +from .base import BaseHTML, VoidBaseHTML class Blockquote(BaseHTML): @@ -52,7 +52,7 @@ class Figure(BaseHTML): tag = "figure" -class Hr(BaseHTML): +class Hr(VoidBaseHTML): """Display the hr element.""" tag = "hr" diff --git a/packages/reflex-components-radix/src/reflex_components_radix/plugin.py b/packages/reflex-components-radix/src/reflex_components_radix/plugin.py new file mode 100644 index 00000000000..150183a17ad --- /dev/null +++ b/packages/reflex-components-radix/src/reflex_components_radix/plugin.py @@ -0,0 +1,115 @@ +"""Plugin support for opt-in Radix Themes integration.""" + +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Any + +from reflex_base.components.component import BaseComponent, Component +from reflex_base.components.dynamic import bundle_library +from reflex_base.plugins.base import Plugin +from reflex_base.utils import console + +from reflex_components_radix import themes +from reflex_components_radix.themes.base import RadixThemesComponent + +if TYPE_CHECKING: + from reflex_base.plugins.compiler import PageContext + + +RADIX_THEMES_STYLESHEET = "@radix-ui/themes/styles.css" +RADIX_THEMES_PACKAGE = "@radix-ui/themes@3.3.0" +_DEPRECATION_VERSION = "0.9.0" +_REMOVAL_VERSION = "1.0" + + +@dataclasses.dataclass +class RadixThemesPlugin(Plugin): + """Opt-in plugin for Radix Themes assets and app-level wrapping.""" + + theme: Component | None = dataclasses.field( + default_factory=lambda: themes.theme(accent_color="blue") + ) + enabled: bool = dataclasses.field(default=True, repr=False) + _explicit: bool = dataclasses.field(default=True, repr=False) + _app_theme_warning_emitted: bool = dataclasses.field( + default=False, init=False, repr=False + ) + + @classmethod + def create_implicit(cls) -> RadixThemesPlugin: + """Create a compile-local plugin that starts disabled. + + Returns: + The disabled compile-local plugin. + """ + return cls(enabled=False, _explicit=False) + + def get_stylesheet_paths(self, **context: Any) -> tuple[str, ...]: + """Return the Radix Themes stylesheet when enabled.""" + return (RADIX_THEMES_STYLESHEET,) if self.enabled else () + + def get_frontend_dependencies(self, **context: Any) -> tuple[str, ...]: + """Return the Radix Themes package when enabled.""" + return (RADIX_THEMES_PACKAGE,) if self.enabled else () + + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: Any, + in_prop_tree: bool = False, + ) -> None: + """Auto-enable the plugin when a Radix Themes component is compiled.""" + if self.enabled or not isinstance(comp, RadixThemesComponent): + return + + self.enabled = True + bundle_library(RADIX_THEMES_PACKAGE) + if not self._explicit and not self._app_theme_warning_emitted: + console.deprecate( + feature_name="Implicit Radix Themes enablement", + reason=( + "a Radix Themes component was detected, which enables the full " + "Radix CSS bundle. Configure `rx.plugins.RadixThemesPlugin()` in " + "`rxconfig.py` to make this explicit, or remove Radix components " + "to avoid loading the stylesheet" + ), + deprecation_version=_DEPRECATION_VERSION, + removal_version=_REMOVAL_VERSION, + ) + + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Inject the app-level theme wrapper when Radix Themes is active.""" + if self.enabled and self.theme is not None: + page_ctx.app_wrap_components[20, "Theme"] = self.theme + + def get_theme(self) -> Component | None: + """Return the effective theme component for the active compile.""" + return self.theme if self.enabled else None + + def apply_app_theme(self, theme: Component) -> None: + """Handle deprecated ``App(theme=...)`` compatibility.""" + console.deprecate( + feature_name="App(theme=...)", + reason=( + "configure `rx.plugins.RadixThemesPlugin(theme=...)` in " + "`rxconfig.py` instead" + ), + deprecation_version=_DEPRECATION_VERSION, + removal_version=_REMOVAL_VERSION, + ) + self._app_theme_warning_emitted = True + + if self._explicit: + return + + self.enabled = True + self.theme = theme diff --git a/pyi_hashes.json b/pyi_hashes.json index b0bf954a0a2..db5a99505e0 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -8,7 +8,7 @@ "packages/reflex-components-core/src/reflex_components_core/base/document.pyi": "a2e67a9814dc61853ca2299d9d9c698d", "packages/reflex-components-core/src/reflex_components_core/base/error_boundary.pyi": "59170074a1a228ce58685f3f207954f2", "packages/reflex-components-core/src/reflex_components_core/base/fragment.pyi": "e4cbfc46eabb904596be4372392add35", - "packages/reflex-components-core/src/reflex_components_core/base/link.pyi": "629a483c570b04ca3d83ecdc53914770", + "packages/reflex-components-core/src/reflex_components_core/base/link.pyi": "005866cf4d1cc8ac7693ed6baeca2289", "packages/reflex-components-core/src/reflex_components_core/base/meta.pyi": "0cfa2d8c52321ce7440e887d03007d5b", "packages/reflex-components-core/src/reflex_components_core/base/script.pyi": "bfc7fb609b822f597d1141595f8090fe", "packages/reflex-components-core/src/reflex_components_core/base/strict_mode.pyi": "8ee129808abb4389cbd77a1736190eae", @@ -16,26 +16,26 @@ "packages/reflex-components-core/src/reflex_components_core/core/auto_scroll.pyi": "918dfad4d5925addd0f741e754b3b076", "packages/reflex-components-core/src/reflex_components_core/core/banner.pyi": "6040fbada9b96c55637a9c8cc21a5e10", "packages/reflex-components-core/src/reflex_components_core/core/clipboard.pyi": "e3950e0963a6d04299ff58294687e407", - "packages/reflex-components-core/src/reflex_components_core/core/debounce.pyi": "dd221754c5e076a3a833c8584da72dc5", + "packages/reflex-components-core/src/reflex_components_core/core/debounce.pyi": "58138b5f1d5901839729d839620ea4da", "packages/reflex-components-core/src/reflex_components_core/core/helmet.pyi": "7fd81a99bde5b0ff94bb52523597fd5c", "packages/reflex-components-core/src/reflex_components_core/core/html.pyi": "753d6ae315369530dad450ed643f5be6", "packages/reflex-components-core/src/reflex_components_core/core/sticky.pyi": "ba60a7d9cba75b27a1133bd63a9fbd59", - "packages/reflex-components-core/src/reflex_components_core/core/upload.pyi": "48bccd1cf1cedaf24503e7dbd2c0b02b", - "packages/reflex-components-core/src/reflex_components_core/core/window_events.pyi": "cab827931770be082cd1598a9908abbc", + "packages/reflex-components-core/src/reflex_components_core/core/upload.pyi": "2dd6ba6e3a4d61fc1d79eb582a7cc548", + "packages/reflex-components-core/src/reflex_components_core/core/window_events.pyi": "5e1dcb1130bc8af282783fae329ae6a6", "packages/reflex-components-core/src/reflex_components_core/datadisplay/__init__.pyi": "c96fed4da42a13576d64f84e3c7cb25c", "packages/reflex-components-core/src/reflex_components_core/el/__init__.pyi": "f09129ddefb57ab4c7769c86dc9a3153", "packages/reflex-components-core/src/reflex_components_core/el/element.pyi": "ff68d843c5987d3f0d773a6367eb9c63", "packages/reflex-components-core/src/reflex_components_core/el/elements/__init__.pyi": "e6c845f2f29eb079697a2e31b0c2f23a", - "packages/reflex-components-core/src/reflex_components_core/el/elements/base.pyi": "d2500a39e6e532bb90c83438343905bf", - "packages/reflex-components-core/src/reflex_components_core/el/elements/forms.pyi": "ca840a20c8e1c1f5335fb815a25b6c32", - "packages/reflex-components-core/src/reflex_components_core/el/elements/inline.pyi": "c38a432d1fd0c3208c4fc3a546c67e4d", - "packages/reflex-components-core/src/reflex_components_core/el/elements/media.pyi": "b794f4f4f7ad17c6939d5526b9c63397", - "packages/reflex-components-core/src/reflex_components_core/el/elements/metadata.pyi": "f9e51feebda79fb063bc264a235df0c3", + "packages/reflex-components-core/src/reflex_components_core/el/elements/base.pyi": "a3ef8bcb5fe8e4bfb22a8f6d714611b8", + "packages/reflex-components-core/src/reflex_components_core/el/elements/forms.pyi": "ab968cdfc51968d6c0c4e8a884c4f246", + "packages/reflex-components-core/src/reflex_components_core/el/elements/inline.pyi": "9c1432e70e6b9349f44df04a244a4303", + "packages/reflex-components-core/src/reflex_components_core/el/elements/media.pyi": "f51120c31a1a8b79da9ecf58f19005b9", + "packages/reflex-components-core/src/reflex_components_core/el/elements/metadata.pyi": "73d19f3d9e389447ad8bbb68e1b7d1c9", "packages/reflex-components-core/src/reflex_components_core/el/elements/other.pyi": "c86abf00384b5f15725a0daf2533848d", - "packages/reflex-components-core/src/reflex_components_core/el/elements/scripts.pyi": "222176bffc14191018fd0e3af3741aff", + "packages/reflex-components-core/src/reflex_components_core/el/elements/scripts.pyi": "903432e316a781b342f2b8d334952da1", "packages/reflex-components-core/src/reflex_components_core/el/elements/sectioning.pyi": "fbbe0bf222d4196c32c88d05cb077997", - "packages/reflex-components-core/src/reflex_components_core/el/elements/tables.pyi": "cba93678248925c981935a251379aa7c", - "packages/reflex-components-core/src/reflex_components_core/el/elements/typography.pyi": "4abedc6f98f6d54194ff9e7f1f76314e", + "packages/reflex-components-core/src/reflex_components_core/el/elements/tables.pyi": "93a69aab9a6f519e3f293d439a39786b", + "packages/reflex-components-core/src/reflex_components_core/el/elements/typography.pyi": "2b434f2231d6f21b12d32995ac185e79", "packages/reflex-components-core/src/reflex_components_core/react_router/dom.pyi": "1074a512195ae23d479c4a2d553954e1", "packages/reflex-components-dataeditor/src/reflex_components_dataeditor/dataeditor.pyi": "8e379fa038c7c6c0672639eb5902934d", "packages/reflex-components-gridjs/src/reflex_components_gridjs/datatable.pyi": "d2dc211d707c402eb24678a4cba945f7", @@ -114,11 +114,11 @@ "packages/reflex-components-recharts/src/reflex_components_recharts/__init__.pyi": "7b8b69840a3637c1f1cac45ba815cccf", "packages/reflex-components-recharts/src/reflex_components_recharts/cartesian.pyi": "277bbf09d72e0c450241f0b7d39ebb60", "packages/reflex-components-recharts/src/reflex_components_recharts/charts.pyi": "be20d1d71c3b16f7e973a0329c3d81d6", - "packages/reflex-components-recharts/src/reflex_components_recharts/general.pyi": "c051ab3a26c23107043e203b060e1412", + "packages/reflex-components-recharts/src/reflex_components_recharts/general.pyi": "5a1a479924ad6184abafe4d796cb04c5", "packages/reflex-components-recharts/src/reflex_components_recharts/polar.pyi": "1979bb6c22bb7a0d3342b2d63fb19d74", - "packages/reflex-components-recharts/src/reflex_components_recharts/recharts.pyi": "234407dbd466bf9c87d75ce979ab0e2d", + "packages/reflex-components-recharts/src/reflex_components_recharts/recharts.pyi": "c5288f311fe37b23539518ba2a3d4482", "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "2c5fadcc014056f041cd4d916137d9e7", "reflex/__init__.pyi": "3a9bb8544cbc338ffaf0a5927d9156df", "reflex/components/__init__.pyi": "f39a2af77f438fa243c58c965f19d42e", - "reflex/experimental/memo.pyi": "2c119a0dfea362dcd8193786363cbc02" + "reflex/experimental/memo.pyi": "9946d9b757f7cef5f53d599194d6e50e" } diff --git a/reflex/app.py b/reflex/app.py index 240da421ef5..071b8600e8e 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio -import concurrent.futures import contextlib import copy import dataclasses @@ -15,25 +14,23 @@ import time import traceback import urllib.parse -from collections.abc import AsyncIterator, Callable, Coroutine, Mapping, Sequence +from collections.abc import ( + AsyncIterator, + Callable, + Collection, + Coroutine, + Mapping, + Sequence, +) from contextvars import Token -from datetime import datetime -from itertools import chain -from pathlib import Path -from timeit import default_timer as timer from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, ParamSpec, overload +from typing import TYPE_CHECKING, Any, overload from reflex_base import constants -from reflex_base.components.component import ( - CUSTOM_COMPONENTS, - Component, - ComponentStyle, - evaluate_style_namespaces, -) +from reflex_base.components.component import Component, ComponentStyle from reflex_base.config import get_config from reflex_base.context.base import BaseContext -from reflex_base.environment import ExecutorType, environment +from reflex_base.environment import environment from reflex_base.event import ( _EVENT_FIELDS, Event, @@ -48,10 +45,8 @@ from reflex_base.utils import console from reflex_base.utils.imports import ImportVar from reflex_base.utils.types import ASGIApp, Message, Receive, Scope, Send -from reflex_components_core.base.app_wrap import AppWrap from reflex_components_core.base.error_boundary import ErrorBoundary from reflex_components_core.base.fragment import Fragment -from reflex_components_core.base.strict_mode import StrictMode from reflex_components_core.core.banner import ( backend_disabled, connection_pulser, @@ -59,9 +54,7 @@ ) from reflex_components_core.core.breakpoints import set_breakpoints from reflex_components_core.core.sticky import sticky -from reflex_components_radix import themes from reflex_components_sonner.toast import toast -from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp as EngineIOApp from socketio import AsyncNamespace, AsyncServer from starlette.applications import Starlette @@ -76,13 +69,7 @@ from reflex.admin import AdminDash from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin from reflex.compiler import compiler -from reflex.compiler import utils as compiler_utils -from reflex.compiler.compiler import ( - ExecutorSafeFunctions, - compile_theme, - readable_name_from_component, -) -from reflex.experimental.memo import EXPERIMENTAL_MEMOS +from reflex.compiler.compiler import readable_name_from_component from reflex.istate.manager import StateManager, StateModificationContext from reflex.istate.manager.token import BaseStateToken from reflex.page import DECORATED_PAGES @@ -97,17 +84,8 @@ State, StateUpdate, all_base_state_classes, - code_uses_state_contexts, -) -from reflex.utils import ( - codespaces, - exceptions, - format, - frontend_skeleton, - js_runtimes, - path_ops, - prerequisites, ) +from reflex.utils import codespaces, exceptions, format, js_runtimes, prerequisites from reflex.utils.exec import ( get_compile_context, is_prod_mode, @@ -207,10 +185,10 @@ def extra_overlay_function() -> Component | None: def default_overlay_component() -> Component: - """Default overlay_component attribute for App. + """Default overlay component included in the app wraps. Returns: - The default overlay_component, which is a connection_modal. + The default overlay component, which is a connection banner/toaster set. """ from reflex_base.components.component import memo @@ -254,12 +232,12 @@ class UnevaluatedPage: component: Component | ComponentCallable route: str - title: Var | str | None - description: Var | str | None - image: str - on_load: EventType[()] | None - meta: Sequence[Mapping[str, Any] | Component] - context: Mapping[str, Any] + title: Var | str | None = None + description: Var | str | None = None + image: str = "" + on_load: EventType[()] | None = None + meta: Sequence[Mapping[str, Any] | Component] = () + context: Mapping[str, Any] = dataclasses.field(default_factory=dict) def merged_with(self, other: UnevaluatedPage) -> UnevaluatedPage: """Merge the other page into this one. @@ -281,9 +259,6 @@ def merged_with(self, other: UnevaluatedPage) -> UnevaluatedPage: ) -P = ParamSpec("P") - - @dataclasses.dataclass() class App(MiddlewareMixin, LifespanMixin): """The main Reflex app that encapsulates the backend and frontend. @@ -300,17 +275,16 @@ class App(MiddlewareMixin, LifespanMixin): app = rx.App( # Set global level style. style={...}, - # Set the top level theme. + # Deprecated legacy shortcut for the Radix Themes plugin. theme=rx.theme(accent_color="blue"), ) ``` Attributes: - theme: The global [theme](https://reflex.dev/docs/styling/theming/#theme) for the entire app. + theme: Deprecated legacy shortcut for configuring the app-level Radix theme. style: The [global style](https://reflex.dev/docs/styling/overview/#global-styles}) for the app. stylesheets: A list of URLs to [stylesheets](https://reflex.dev/docs/styling/custom-stylesheets/) to include in the app. reset_style: Whether to include CSS reset for margin and padding. Defaults to True. - overlay_component: A component that is present on every page. Defaults to the Connection Error banner. app_wraps: App wraps to be applied to the whole app. Expected to be a dictionary of (order, name) to a function that takes whether the state is enabled and optionally returns a component. extra_app_wraps: Extra app wraps to be applied to the whole app. head_components: Components to add to the head of every page. @@ -325,9 +299,7 @@ class App(MiddlewareMixin, LifespanMixin): api_transformer: Transform the ASGI app before running it. """ - theme: Component | None = dataclasses.field( - default_factory=lambda: themes.theme(accent_color="blue") - ) + theme: Component | None = dataclasses.field(default=None) style: ComponentStyle = dataclasses.field(default_factory=dict) @@ -947,7 +919,10 @@ def _compile_page(self, route: str, save_page: bool = True): """ n_states_before = len(all_base_state_classes) component = compiler.compile_unevaluated_page( - route, self._unevaluated_pages[route], self.style, self.theme + route, + self._unevaluated_pages[route], + self.style, + self.theme, ) # Indicate that evaluating this page creates one or more state classes. @@ -1059,7 +1034,10 @@ def _setup_admin_dash(self): admin.mount_to(self._api) - def _get_frontend_packages(self, imports: dict[str, set[ImportVar]]): + def _get_frontend_packages( + self, + imports: Mapping[str, Collection[ImportVar]], + ) -> None: """Gets the frontend packages to be installed and filters out the unnecessary ones. Args: @@ -1133,16 +1111,6 @@ def _should_compile(self) -> bool: # By default, compile the app. return True - def _add_overlay_to_component( - self, component: Component, overlay_component: Component - ) -> Component: - children = component.children - - if children[0] == overlay_component: - return component - - return Fragment.create(overlay_component, *children) - def _setup_sticky_badge(self): """Add the sticky badge to the app.""" from reflex_base.components.component import memo @@ -1211,397 +1179,13 @@ def _compile( ReflexRuntimeError: When any page uses state, but no rx.State subclass is defined. FileNotFoundError: When a plugin requires a file that does not exist. """ - from reflex_base.utils.exceptions import ReflexRuntimeError - - self._apply_decorated_pages() - - self._pages = {} - - def get_compilation_time() -> str: - return str(datetime.now().time()).split(".")[0] - - should_compile = self._should_compile() - backend_dir = prerequisites.get_backend_dir() - if not dry_run and not should_compile and backend_dir.exists(): - stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES - if stateful_pages_marker.exists(): - with stateful_pages_marker.open("r") as f: - stateful_pages = json.load(f) - for route in stateful_pages: - console.debug(f"BE Evaluating stateful page: {route}") - self._compile_page(route, save_page=False) - self._add_optional_endpoints() - return - - # Render a default 404 page if the user didn't supply one - if constants.Page404.SLUG not in self._unevaluated_pages: - self.add_page(route=constants.Page404.SLUG) - - # Fix up the style. - self.style = evaluate_style_namespaces(self.style) - - # Add the app wrappers. - app_wrappers: dict[tuple[int, str], Component] = { - # Default app wrap component renders {children} - (0, "AppWrap"): AppWrap.create() - } - - if self.theme is not None: - # If a theme component was provided, wrap the app with it - app_wrappers[20, "Theme"] = self.theme - - # Get the env mode. - config = get_config() - - if config.react_strict_mode: - app_wrappers[200, "StrictMode"] = StrictMode.create() - - if not should_compile and not dry_run: - with console.timing("Evaluate Pages (Backend)"): - for route in self._unevaluated_pages: - console.debug(f"Evaluating page: {route}") - self._compile_page(route, save_page=should_compile) - - # Save the pages which created new states at eval time. - self._write_stateful_pages_marker() - - # Add the optional endpoints (_upload) - self._add_optional_endpoints() - - return - - # Create a progress bar. - progress = ( - Progress( - *Progress.get_default_columns()[:-1], - MofNCompleteColumn(), - TimeElapsedColumn(), - ) - if use_rich - else console.PoorProgress() - ) - - # try to be somewhat accurate - but still not 100% - adhoc_steps_without_executor = 8 - fixed_pages_within_executor = 4 - plugin_count = len(config.plugins) - progress.start() - task = progress.add_task( - f"[{get_compilation_time()}] Compiling:", - total=len(self._unevaluated_pages) - + ((len(self._unevaluated_pages) + len(self._pages)) * 3) - + fixed_pages_within_executor - + adhoc_steps_without_executor - + plugin_count, - ) - - with console.timing("Evaluate Pages (Frontend)"): - performance_metrics: list[tuple[str, float]] = [] - for route in self._unevaluated_pages: - console.debug(f"Evaluating page: {route}") - start = timer() - self._compile_page(route, save_page=should_compile) - end = timer() - performance_metrics.append((route, end - start)) - progress.advance(task) - console.debug( - "Slowest pages:\n" - + "\n".join( - f"{route}: {time * 1000:.1f}ms" - for route, time in sorted( - performance_metrics, key=operator.itemgetter(1), reverse=True - )[:10] - ) - ) - # Save the pages which created new states at eval time. - self._write_stateful_pages_marker() - - # Add the optional endpoints (_upload) - self._add_optional_endpoints() - - self._validate_var_dependencies() - - if config.show_built_with_reflex is None: - if ( - get_compile_context() == constants.CompileContext.DEPLOY - and prerequisites.get_user_tier() in ["pro", "team", "enterprise"] - ): - config.show_built_with_reflex = False - else: - config.show_built_with_reflex = True - - if is_prod_mode() and config.show_built_with_reflex: - self._setup_sticky_badge() - - progress.advance(task) - - # Store the compile results. - compile_results: list[tuple[str, str]] = [] - - progress.advance(task) - - # Reinitialize vite config in case runtime options have changed. - compile_results.append(( - constants.ReactRouter.VITE_CONFIG_FILE, - frontend_skeleton._compile_vite_config(config), - )) - progress.advance(task) - - # Track imports found. - all_imports = {} - - if (toaster := self.toaster) is not None: - from reflex_base.components.component import memo - - @memo - def memoized_toast_provider(): - return toaster - - toast_provider = Fragment.create(memoized_toast_provider()) - - app_wrappers[44, "ToasterProvider"] = toast_provider - - # Add the app wraps to the app. - for key, app_wrap in chain( - self.app_wraps.items(), self.extra_app_wraps.items() - ): - # If the app wrap is a callable, generate the component - component = app_wrap(self._state is not None) - if component is not None: - app_wrappers[key] = component - - # Compile custom components. - ( - memo_components_output, - memo_components_result, - memo_components_imports, - ) = compiler.compile_memo_components( - dict.fromkeys(CUSTOM_COMPONENTS.values()), - tuple(EXPERIMENTAL_MEMOS.values()), - ) - compile_results.append((memo_components_output, memo_components_result)) - all_imports.update(memo_components_imports) - progress.advance(task) - - with console.timing("Collect all imports and app wraps"): - # This has to happen before compiling stateful components as that - # prevents recursive functions from reaching all components. - for component in self._pages.values(): - # Add component._get_all_imports() to all_imports. - all_imports.update(component._get_all_imports()) - - # Add the app wrappers from this component. - app_wrappers.update(component._get_all_app_wrap_components()) - - progress.advance(task) - - # Perform auto-memoization of stateful components. - with console.timing("Auto-memoize StatefulComponents"): - ( - stateful_components_path, - stateful_components_code, - page_components, - ) = compiler.compile_stateful_components( - self._pages.values(), - progress_function=lambda task=task: progress.advance(task), - ) - progress.advance(task) - - # Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State. - if code_uses_state_contexts(stateful_components_code) and self._state is None: - msg = ( - "To access rx.State in frontend components, at least one " - "subclass of rx.State must be defined in the app." - ) - raise ReflexRuntimeError(msg) - compile_results.append((stateful_components_path, stateful_components_code)) - - progress.advance(task) - - # Compile the root document before fork. - compile_results.append( - compiler.compile_document_root( - self.head_components, - html_lang=self.html_lang, - html_custom_attrs=( - {"suppressHydrationWarning": True, **self.html_custom_attrs} - if self.html_custom_attrs - else {"suppressHydrationWarning": True} - ), - ) - ) - - progress.advance(task) - - # Copy the assets. - assets_src = Path.cwd() / constants.Dirs.APP_ASSETS - if assets_src.is_dir() and not dry_run: - with console.timing("Copy assets"): - path_ops.update_directory_tree( - src=assets_src, - dest=( - Path.cwd() / prerequisites.get_web_dir() / constants.Dirs.PUBLIC - ), - ) - - executor = ExecutorType.get_executor_from_environment() - - for route, component in zip(self._pages, page_components, strict=True): - ExecutorSafeFunctions.COMPONENTS[route] = component - - modify_files_tasks: list[tuple[str, str, Callable[[str], str]]] = [] - - with console.timing("Compile to Javascript"), executor as executor: - result_futures: list[ - concurrent.futures.Future[ - list[tuple[str, str]] | tuple[str, str] | None - ] - ] = [] - - def _submit_work( - fn: Callable[P, list[tuple[str, str]] | tuple[str, str] | None], - *args: P.args, - **kwargs: P.kwargs, - ): - f = executor.submit(fn, *args, **kwargs) - f.add_done_callback(lambda _: progress.advance(task)) - result_futures.append(f) - - # Compile the pre-compiled pages. - for route in self._pages: - _submit_work( - ExecutorSafeFunctions.compile_page, - route, - ) - - # Compile the root stylesheet with base styles. - _submit_work( - compiler.compile_root_stylesheet, self.stylesheets, self.reset_style - ) - - # Compile the theme. - _submit_work(compile_theme, self.style) - - def _submit_work_without_advancing( - fn: Callable[P, list[tuple[str, str]] | tuple[str, str] | None], - *args: P.args, - **kwargs: P.kwargs, - ): - f = executor.submit(fn, *args, **kwargs) - result_futures.append(f) - - for plugin in config.plugins: - plugin.pre_compile( - add_save_task=_submit_work_without_advancing, - add_modify_task=( - lambda *args, plugin=plugin: modify_files_tasks.append(( - plugin.__class__.__module__ + plugin.__class__.__name__, - *args, - )) - ), - unevaluated_pages=list(self._unevaluated_pages.values()), - ) - - # Wait for all compilation tasks to complete. - for future in concurrent.futures.as_completed(result_futures): - if (result := future.result()) is not None: - if isinstance(result, list): - compile_results.extend(result) - else: - compile_results.append(result) - - progress.advance(task, advance=len(config.plugins)) - - app_root = self._app_root(app_wrappers=app_wrappers) - - # Get imports from AppWrap components. - all_imports.update(app_root._get_all_imports()) - - progress.advance(task) - - # Compile the contexts. - compile_results.append( - compiler.compile_contexts(self._state, self.theme), - ) - if self.theme is not None: - # Fix #2992 by removing the top-level appearance prop - self.theme.appearance = None # pyright: ignore[reportAttributeAccessIssue] - progress.advance(task) - - # Compile the app root. - compile_results.append( - compiler.compile_app(app_root), - ) - progress.advance(task) - - progress.stop() - - if dry_run: - return - - # Install frontend packages. - with console.timing("Install Frontend Packages"): - self._get_frontend_packages(all_imports) - - # Setup the react-router.config.js - frontend_skeleton.update_react_router_config( + compiler.compile_app( + self, prerender_routes=prerender_routes, + dry_run=dry_run, + use_rich=use_rich, ) - if is_prod_mode(): - # Empty the .web pages directory. - compiler.purge_web_pages_dir() - else: - # In dev mode, delete removed pages and update existing pages. - keep_files = [Path(output_path) for output_path, _ in compile_results] - for p in Path( - prerequisites.get_web_dir() - / constants.Dirs.PAGES - / constants.Dirs.ROUTES - ).rglob("*"): - if p.is_file() and p not in keep_files: - # Remove pages that are no longer in the app. - p.unlink() - - output_mapping: dict[Path, str] = {} - for output_path, code in compile_results: - path = compiler_utils.resolve_path_of_web_dir(output_path) - if path in output_mapping: - console.warn( - f"Path {path} has two different outputs. The first one will be used." - ) - else: - output_mapping[path] = code - - for plugin in config.plugins: - for static_file_path, content in plugin.get_static_assets(): - path = compiler_utils.resolve_path_of_web_dir(static_file_path) - if path in output_mapping: - console.warn( - f"Plugin {plugin.__class__.__name__} is trying to write to {path} but it already exists. The plugin file will be ignored." - ) - else: - output_mapping[path] = ( - content.decode("utf-8") - if isinstance(content, bytes) - else content - ) - - for plugin_name, file_path, modify_fn in modify_files_tasks: - path = compiler_utils.resolve_path_of_web_dir(file_path) - file_content = output_mapping.get(path) - if file_content is None: - if path.exists(): - file_content = path.read_text() - else: - msg = f"Plugin {plugin_name} is trying to modify {path} but it does not exist." - raise FileNotFoundError(msg) - output_mapping[path] = modify_fn(file_content) - - with console.timing("Write to Disk"): - for output_path, code in output_mapping.items(): - compiler_utils.write_file(output_path, code) - def _write_stateful_pages_marker(self): """Write list of routes that create dynamic states for the backend to use later.""" if self._state is not None: diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index bac00517bb0..13642e7e458 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -134,7 +134,7 @@ async def _run_lifespan_tasks(self, app: Starlette): # Flush any pending writes from the state manager. try: state_manager = self.state_manager # pyright: ignore[reportAttributeAccessIssue] - except AttributeError: + except (AttributeError, ValueError): pass else: await state_manager.close() diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index feec49ee69a..39ff4931a9e 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import sys from collections.abc import Callable, Iterable, Sequence from inspect import getmodule @@ -10,34 +11,52 @@ from reflex_base import constants from reflex_base.components.component import ( + CUSTOM_COMPONENTS, BaseComponent, Component, ComponentStyle, CustomComponent, - StatefulComponent, + evaluate_style_namespaces, ) from reflex_base.config import get_config from reflex_base.constants.compiler import PageNames, ResetStylesheet from reflex_base.constants.state import FIELD_MARKER from reflex_base.environment import environment +from reflex_base.plugins import CompileContext, CompilerHooks, PageContext, Plugin from reflex_base.style import SYSTEM_COLOR_MODE from reflex_base.utils.exceptions import ReflexError from reflex_base.utils.format import to_title_case -from reflex_base.utils.imports import ImportVar, ParsedImportDict +from reflex_base.utils.imports import ImportVar from reflex_base.vars.base import LiteralVar, Var +from reflex_components_core.base.app_wrap import AppWrap from reflex_components_core.base.fragment import Fragment +from reflex_components_radix.plugin import RadixThemesPlugin +from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from reflex.compiler import templates, utils +from reflex.compiler.plugins import default_page_plugins from reflex.experimental.memo import ( + EXPERIMENTAL_MEMOS, ExperimentalMemoComponentDefinition, ExperimentalMemoDefinition, ExperimentalMemoFunctionDefinition, ) -from reflex.state import BaseState -from reflex.utils import console, path_ops -from reflex.utils.exec import is_prod_mode +from reflex.state import BaseState, code_uses_state_contexts +from reflex.utils import console, frontend_skeleton, path_ops, prerequisites +from reflex.utils.exec import get_compile_context, is_prod_mode from reflex.utils.prerequisites import get_web_dir +RADIX_THEMES_STYLESHEET = "@radix-ui/themes/styles.css" + + +def _set_progress_total( + progress: Progress | console.PoorProgress, + task: Any, + total: int, +) -> None: + """Update a task total for either rich or fallback progress bars.""" + progress.update(task, total=total) + def _apply_common_imports( imports: dict[str, list[ImportVar]], @@ -48,6 +67,36 @@ def _apply_common_imports( ) +def _extend_imports_in_place( + target: dict[str, list[ImportVar]], + import_dict: dict[str, Any] | tuple[tuple[str, Any], ...], +) -> None: + """Append imports to an existing parsed import dict. + + Args: + target: The import dictionary to update. + import_dict: The imports to append. + """ + for lib, fields in ( + import_dict if isinstance(import_dict, tuple) else import_dict.items() + ): + lib = ( + "$" + lib + if lib.startswith(("/utils/", "/components/", "/styles/", "/public/")) + else lib + ) + target_fields = target.setdefault(lib, []) + if isinstance(fields, (list, tuple, set)): + target_fields.extend( + ImportVar(field) if isinstance(field, str) else field + for field in fields + ) + else: + target_fields.append( + ImportVar(fields) if isinstance(fields, str) else fields + ) + + def _compile_document_root(root: Component) -> str: """Compile the document root. @@ -175,20 +224,23 @@ def _compile_page(component: BaseComponent) -> str: def compile_root_stylesheet( - stylesheets: list[str], reset_style: bool = True + stylesheets: list[str], + reset_style: bool = True, + plugins: Sequence[Plugin] | None = None, ) -> tuple[str, str]: """Compile the root stylesheet. Args: stylesheets: The stylesheets to include in the root stylesheet. reset_style: Whether to include CSS reset for margin and padding. + plugins: The effective plugins for the active compile. Returns: The path and code of the compiled root stylesheet. """ output_path = utils.get_root_stylesheet_path() - code = _compile_root_stylesheet(stylesheets, reset_style) + code = _compile_root_stylesheet(stylesheets, reset_style, plugins) return output_path, code @@ -228,15 +280,17 @@ def _validate_stylesheet(stylesheet_full_path: Path, assets_app_path: Path) -> N raise ValueError(msg) -RADIX_THEMES_STYLESHEET = "@radix-ui/themes/styles.css" - - -def _compile_root_stylesheet(stylesheets: list[str], reset_style: bool = True) -> str: +def _compile_root_stylesheet( + stylesheets: list[str], + reset_style: bool = True, + plugins: Sequence[Plugin] | None = None, +) -> str: """Compile the root stylesheet. Args: stylesheets: The stylesheets to include in the root stylesheet. reset_style: Whether to include CSS reset for margin and padding. + plugins: The effective plugins for the active compile. Returns: The compiled root stylesheet. @@ -252,14 +306,10 @@ def _compile_root_stylesheet(stylesheets: list[str], reset_style: bool = True) - # Reference the vendored style reset file (automatically copied from .templates/web) sheets.append(f"./{ResetStylesheet.FILENAME}") - sheets.extend( - [RADIX_THEMES_STYLESHEET] - + [ - sheet - for plugin in get_config().plugins - for sheet in plugin.get_stylesheet_paths() - ] - ) + active_plugins = get_config().plugins if plugins is None else plugins + sheets.extend([ + sheet for plugin in active_plugins for sheet in plugin.get_stylesheet_paths() + ]) failed_to_import_sass = False assets_app_path = Path.cwd() / constants.Dirs.APP_ASSETS @@ -331,7 +381,7 @@ def _compile_root_stylesheet(stylesheets: list[str], reset_style: bool = True) - return templates.styles_template(stylesheets=sheets) -def _compile_component(component: Component | StatefulComponent) -> str: +def _compile_component(component: Component) -> str: """Compile a single component. Args: @@ -346,149 +396,152 @@ def _compile_component(component: Component | StatefulComponent) -> str: def _compile_memo_components( components: Iterable[CustomComponent], experimental_memos: Iterable[ExperimentalMemoDefinition] = (), -) -> tuple[str, dict[str, list[ImportVar]]]: - """Compile the components. +) -> tuple[list[tuple[str, str]], dict[str, list[ImportVar]]]: + """Compile each memo/custom-component as its own module plus an index. + + Each memo lands in ``.web/<components>/<name>.jsx`` with only the imports + it actually uses. Experimental memo wrappers declare their ``library`` as + that per-memo file path so page-side imports resolve directly to the + individual module. + + The ``$/utils/components`` index only re-exports the legacy + ``@rx.memo`` custom components, which are the ones app-level code + (``root.jsx``) imports by name. Keeping experimental memos out of the + index is what lets root's ``import * as utils_components`` avoid + transitively dragging every page-specific memo into the always-loaded + chunk — the tree-shaking win of per-memo files relies on that. Args: components: The components to compile. experimental_memos: The experimental memos to compile. Returns: - The compiled components. + A list of ``(path, code)`` pairs to write — one per memo plus one + index — and the aggregated imports across all memo modules. """ - imports: dict[str, list[ImportVar]] = {} - component_renders = [] - function_renders = [] + per_memo_files: list[tuple[str, str]] = [] + # Only legacy custom components go through the index: they are the ones + # root.jsx/custom code imports by name from ``$/utils/components``. + # Experimental memos declare their library per-file (see + # ``_get_experimental_memo_component_class``) so pages import them + # directly and the index stays small. + index_entries: list[tuple[str, str]] = [] + aggregate_imports: dict[str, list[ImportVar]] = {} + + base_dir = utils.get_memo_components_dir() - # Compile each component. for component in components: component_render, component_imports = utils.compile_custom_component(component) - component_renders.append(component_render) - imports = utils.merge_imports(imports, component_imports) + name = component_render["name"] + code, file_imports = _compile_single_memo_component( + component_render, component_imports + ) + path = _memo_component_file_path(base_dir, name) + specifier = _memo_component_index_specifier(name) + per_memo_files.append((path, code)) + index_entries.append((name, specifier)) + _extend_imports_in_place(aggregate_imports, file_imports) for memo in experimental_memos: if isinstance(memo, ExperimentalMemoComponentDefinition): memo_render, memo_imports = utils.compile_experimental_component_memo(memo) - component_renders.append(memo_render) - imports = utils.merge_imports(imports, memo_imports) + name = memo_render["name"] + code, file_imports = _compile_single_memo_component( + memo_render, memo_imports + ) + path = _memo_component_file_path(base_dir, name) + per_memo_files.append((path, code)) + _extend_imports_in_place(aggregate_imports, file_imports) elif isinstance(memo, ExperimentalMemoFunctionDefinition): memo_render, memo_imports = utils.compile_experimental_function_memo(memo) - function_renders.append(memo_render) - imports = utils.merge_imports(imports, memo_imports) - - if component_renders: - imports = utils.merge_imports( - { - "react": [ImportVar(tag="memo")], - f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], - }, - imports, - ) - _apply_common_imports(imports) - - dynamic_imports = { - comp_import: None - for comp_render in component_renders - if "dynamic_imports" in comp_render - for comp_import in comp_render["dynamic_imports"] - } - - custom_codes = { - comp_custom_code: None - for comp_render in component_renders - for comp_custom_code in comp_render.get("custom_code", []) - } - - # Compile the components page. - return ( - templates.memo_components_template( - imports=utils.compile_imports(imports), - components=component_renders, - functions=function_renders, - dynamic_imports=sorted(dynamic_imports), - custom_codes=custom_codes, - ), - imports, - ) + name = memo_render["name"] + code, file_imports = _compile_single_memo_function( + memo_render, memo_imports + ) + path = _memo_component_file_path(base_dir, name) + per_memo_files.append((path, code)) + _extend_imports_in_place(aggregate_imports, file_imports) + index_path = utils.get_components_path() + index_code = templates.memo_index_template(index_entries) + return [(index_path, index_code), *per_memo_files], aggregate_imports -def _get_shared_components_recursive( - component: BaseComponent, - rendered_components: dict[str, None], - all_import_dicts: list[ParsedImportDict], -): - """Get the shared components for a component and its children. - A shared component is a StatefulComponent that appears in 2 or more - pages and is a candidate for writing to a common file and importing - into each page where it is used. +def _compile_single_memo_component( + component_render: dict, + component_imports: dict[str, list[ImportVar]], +) -> tuple[str, dict[str, list[ImportVar]]]: + """Render one memoized component as a standalone module. Args: - component: The component to collect shared StatefulComponents for. - rendered_components: A dict to store the rendered shared components in. - all_import_dicts: A list to store the imports of all shared components in. - """ - for child in component.children: - # Depth-first traversal. - _get_shared_components_recursive(child, rendered_components, all_import_dicts) - - # When the component is referenced by more than one page, render it - # to be included in the STATEFUL_COMPONENTS module. - # Skip this step in dev mode, thereby avoiding potential hot reload errors for larger apps - if isinstance(component, StatefulComponent) and component.references > 1: - # Reset this flag to render the actual component. - component.rendered_as_shared = False + component_render: The component's render dict. + component_imports: The component's imports before common/common-memo + additions. - # Include dynamic imports in the shared component. - if dynamic_imports := component._get_all_dynamic_imports(): - rendered_components.update(dict.fromkeys(dynamic_imports)) + Returns: + The file contents and the full import dict used to compile it. + """ + imports = utils.merge_imports( + { + "react": [ImportVar(tag="memo")], + f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], + }, + component_imports, + ) + _apply_common_imports(imports) + code = templates.memo_single_component_template( + imports=utils.compile_imports(imports), + component=component_render, + dynamic_imports=sorted(component_render.get("dynamic_imports", []) or []), + custom_codes=component_render.get("custom_code", []) or [], + ) + return code, imports - # Include custom code in the shared component. - rendered_components.update(component._get_all_custom_code(export=True)) - # Include all imports in the shared component. - all_import_dicts.append(component._get_all_imports()) +def _compile_single_memo_function( + function_render: dict, + function_imports: dict[str, list[ImportVar]], +) -> tuple[str, dict[str, list[ImportVar]]]: + """Render one function memo as a standalone module. - # Indicate that this component now imports from the shared file. - component.rendered_as_shared = True + Args: + function_render: The function's render dict. + function_imports: The function's imports. + Returns: + The file contents and the full import dict used to compile it. + """ + imports = utils.merge_imports({}, function_imports) + code = templates.memo_single_function_template( + imports=utils.compile_imports(imports), + function=function_render, + ) + return code, imports -def _compile_stateful_components( - page_components: list[BaseComponent], -) -> str: - """Walk the page components and extract shared stateful components. - Any StatefulComponent that is shared by more than one page will be rendered - to a separate module and marked rendered_as_shared so subsequent - renderings will import the component from the shared module instead of - directly including the code for it. +def _memo_component_file_path(base_dir: str, name: str) -> str: + """Return the on-disk path for a per-memo module. Args: - page_components: The Components or StatefulComponents to compile. + base_dir: The directory that holds per-memo files. + name: The memo's export name. Returns: - The rendered stateful components code. + The absolute path for the memo's ``.jsx`` file. """ - all_import_dicts = [] - rendered_components = {} + return str(Path(base_dir) / f"{name}{constants.Ext.JSX}") - for page_component in page_components: - _get_shared_components_recursive( - page_component, rendered_components, all_import_dicts - ) - # Don't import from the file that we're about to create. - all_imports = utils.merge_imports(*all_import_dicts) - all_imports.pop( - f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None - ) - if rendered_components: - _apply_common_imports(all_imports) +def _memo_component_index_specifier(name: str) -> str: + """Return the module specifier the index uses to re-export a memo. - return templates.stateful_components_template( - imports=utils.compile_imports(all_imports), - memoized_code="\n".join(rendered_components), - ) + Args: + name: The memo's export name. + + Returns: + A relative specifier resolvable from the memo index module. + """ + return f"./{constants.PageNames.COMPONENTS}/{name}" def compile_document_root( @@ -521,7 +574,7 @@ def compile_document_root( return output_path, code -def compile_app(app_root: Component) -> tuple[str, str]: +def compile_app_root(app_root: Component) -> tuple[str, str]: """Compile the app root. Args: @@ -596,55 +649,49 @@ def compile_page(path: str, component: BaseComponent) -> tuple[str, str]: return output_path, code -def compile_memo_components( - components: Iterable[CustomComponent], - experimental_memos: Iterable[ExperimentalMemoDefinition] = (), -) -> tuple[str, str, dict[str, list[ImportVar]]]: - """Compile the custom components. +def compile_page_from_context(page_ctx: PageContext) -> tuple[str, str]: + """Compile a single page from a collected page context. Args: - components: The custom components to compile. - experimental_memos: The experimental memos to compile. + page_ctx: The collected page context to render. Returns: - The path and code of the compiled components. + The path and code of the compiled page. """ - # Get the path for the output file. - output_path = utils.get_components_path() - - # Compile the components. - code, imports = _compile_memo_components(components, experimental_memos) - return output_path, code, imports + output_path = utils.get_page_path(page_ctx.route) + imports = { + lib: list(fields) + for lib, fields in ( + page_ctx.frontend_imports or page_ctx.merged_imports(collapse=True) + ).items() + } + _apply_common_imports(imports) + code = templates.page_template( + imports=utils.compile_imports(imports), + dynamic_imports=sorted(page_ctx.dynamic_imports), + custom_codes=page_ctx.custom_code_dict(), + hooks=page_ctx.hooks, + render=page_ctx.root_component.render(), + ) + return output_path, code -def compile_stateful_components( - pages: Iterable[Component], - progress_function: Callable[[], None], -) -> tuple[str, str, list[BaseComponent]]: - """Separately compile components that depend on State vars. - StatefulComponents are compiled as their own component functions with their own - useContext declarations, which allows page components to be stateless and avoid - re-rendering along with parts of the page that actually depend on state. +def compile_memo_components( + components: Iterable[CustomComponent], + experimental_memos: Iterable[ExperimentalMemoDefinition] = (), +) -> tuple[list[tuple[str, str]], dict[str, list[ImportVar]]]: + """Compile the custom components into one module per memo plus an index. Args: - pages: The pages to extract stateful components from. - progress_function: A function to call to indicate progress, called once per page. + components: The custom components to compile. + experimental_memos: The experimental memos to compile. Returns: - The path and code of the compiled stateful components. + A list of ``(path, code)`` pairs (one per memo module and one index) + alongside the aggregated imports across all memo modules. """ - output_path = utils.get_stateful_components_path() - - page_components = [] - for page in pages: - # Compile the stateful components - page_component = StatefulComponent.compile_from(page) or page - progress_function() - page_components.append(page_component) - - code = _compile_stateful_components(page_components) if is_prod_mode() else "" - return output_path, code, page_components + return _compile_memo_components(components, experimental_memos) def purge_web_pages_dir(): @@ -661,7 +708,7 @@ def purge_web_pages_dir(): if TYPE_CHECKING: - from reflex.app import ComponentCallable, UnevaluatedPage + from reflex.app import App, ComponentCallable, UnevaluatedPage def _into_component_once( @@ -871,82 +918,362 @@ def compile_unevaluated_page( return component -class ExecutorSafeFunctions: - """Helper class to allow parallelisation of parts of the compilation process. +def _resolve_app_wrap_components( + app: App, + page_app_wrap_components: dict[tuple[int, str], Component], +) -> dict[tuple[int, str], Component]: + """Build the full app-wrap registry for compilation. + + Args: + app: The app being compiled. + page_app_wrap_components: App-wrap components collected from pages. + + Returns: + The merged app-wrap component registry. + """ + config = get_config() + + app_wrappers: dict[tuple[int, str], Component] = { + (0, "AppWrap"): AppWrap.create(), + } + app_wrappers.update(page_app_wrap_components) - This class (and its class attributes) are available at global scope. + if config.react_strict_mode: + from reflex_components_core.base.strict_mode import StrictMode - In a multiprocessing context (like when using a ProcessPoolExecutor), the content of this - global class is logically replicated to any FORKED process. + app_wrappers[200, "StrictMode"] = StrictMode.create() - How it works: - * Before the child process is forked, ensure that we stash any input data required by any future - function call in the child process. - * After the child process is forked, the child process will have a copy of the global class, which - includes the previously stashed input data. - * Any task submitted to the child process simply needs a way to communicate which input data the - requested function call requires. + if (toaster := app.toaster) is not None: + from reflex_base.components.component import memo - Why do we need this? Passing input data directly to child process often not possible because the input data is not picklable. - The mechanic described here removes the need to pickle the input data at all. + @memo + def memoized_toast_provider(): + return toaster - Limitations: - * This can never support returning unpicklable OUTPUT data. - * Any object mutations done by the child process will not propagate back to the parent process (fork goes one way!). + app_wrappers[44, "ToasterProvider"] = Fragment.create(memoized_toast_provider()) + for wrap_mapping in (app.app_wraps, app.extra_app_wraps): + for key, app_wrap in wrap_mapping.items(): + component = app_wrap(app._state is not None) + if component is not None: + app_wrappers[key] = component + + return app_wrappers + + +def _resolve_radix_themes_plugin( + app: App, + plugins: Sequence[Plugin], +) -> tuple[tuple[Plugin, ...], RadixThemesPlugin]: + """Resolve the effective Radix Themes plugin for the active compile. + + Returns: + The compiler plugin chain and the effective Radix Themes plugin. """ + explicit_plugin = next( + (plugin for plugin in plugins if isinstance(plugin, RadixThemesPlugin)), + None, + ) + if explicit_plugin is not None: + radix_plugin = explicit_plugin + plugin_chain = tuple(plugins) + else: + radix_plugin = RadixThemesPlugin.create_implicit() + plugin_chain = (*plugins, radix_plugin) + + if app.theme is not None: + radix_plugin.apply_app_theme(app.theme) + + return plugin_chain, radix_plugin + + +def compile_app( + app: App, + *, + prerender_routes: bool = False, + dry_run: bool = False, + use_rich: bool = True, +) -> None: + """Compile an app using the compiler plugin pipeline.""" + from reflex_base.components.dynamic import bundle_library, reset_bundled_libraries + from reflex_base.utils.exceptions import ReflexRuntimeError + + app._apply_decorated_pages() + app._pages = {} + + should_compile = app._should_compile() + backend_dir = prerequisites.get_backend_dir() + if not dry_run and not should_compile and backend_dir.exists(): + stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES + if stateful_pages_marker.exists(): + with stateful_pages_marker.open("r") as file: + stateful_pages = json.load(file) + for route in stateful_pages: + console.debug(f"BE Evaluating stateful page: {route}") + app._compile_page(route, save_page=False) + app._add_optional_endpoints() + return + + if constants.Page404.SLUG not in app._unevaluated_pages: + app.add_page(route=constants.Page404.SLUG) - COMPONENTS: dict[str, BaseComponent] = {} - UNCOMPILED_PAGES: dict[str, UnevaluatedPage] = {} - - @classmethod - def compile_page(cls, route: str) -> tuple[str, str]: - """Compile a page. - - Args: - route: The route of the page to compile. - - Returns: - The path and code of the compiled page. - """ - return compile_page(route, cls.COMPONENTS[route]) - - @classmethod - def compile_unevaluated_page( - cls, - route: str, - style: ComponentStyle, - theme: Component | None, - ) -> tuple[str, Component, tuple[str, str]]: - """Compile an unevaluated page. - - Args: - route: The route of the page to compile. - style: The style of the page. - theme: The theme of the page. - - Returns: - The route, compiled component, and compiled page. - """ - component = compile_unevaluated_page( - route, cls.UNCOMPILED_PAGES[route], style, theme + app.style = evaluate_style_namespaces(app.style) + config = get_config() + + if not should_compile and not dry_run: + with console.timing("Evaluate Pages (Backend)"): + for route in app._unevaluated_pages: + console.debug(f"Evaluating page: {route}") + app._compile_page(route, save_page=False) + + app._write_stateful_pages_marker() + app._add_optional_endpoints() + return + + progress = ( + Progress( + *Progress.get_default_columns()[:-1], + MofNCompleteColumn(), + TimeElapsedColumn(), + ) + if use_rich + else console.PoorProgress() + ) + fixed_steps = 7 + compiler_plugins, radix_themes_plugin = _resolve_radix_themes_plugin( + app, + config.plugins, + ) + reset_bundled_libraries() + for plugin in compiler_plugins: + for dependency in plugin.get_frontend_dependencies(): + bundle_library(dependency) + base_total = (len(app._unevaluated_pages) * 2) + fixed_steps + len(config.plugins) + progress.start() + task = progress.add_task("Compiling:", total=base_total) + compile_ctx = CompileContext( + app=app, + pages=list(app._unevaluated_pages.values()), + hooks=CompilerHooks( + plugins=default_page_plugins(style=app.style, plugins=compiler_plugins) + ), + ) + + with console.timing("Compile pages"), compile_ctx: + compile_ctx.compile( + evaluate_progress=lambda: progress.advance(task), + render_progress=lambda: progress.advance(task), ) - return route, component, compile_page(route, component) - @classmethod - def compile_theme(cls, style: ComponentStyle | None) -> tuple[str, str]: - """Compile the theme. + for route, page_ctx in compile_ctx.compiled_pages.items(): + app._check_routes_conflict(route) + if not isinstance(page_ctx.root_component, Component): + msg = ( + f"Compiled page {route!r} root must be a Component before it can " + "be registered on the app." + ) + raise TypeError(msg) + app._pages[route] = page_ctx.root_component + + app._stateful_pages.update(compile_ctx.stateful_routes) + app._write_stateful_pages_marker() + app._add_optional_endpoints() + app._validate_var_dependencies() + + if config.show_built_with_reflex is None: + if ( + get_compile_context() == constants.CompileContext.DEPLOY + and prerequisites.get_user_tier() in ["pro", "team", "enterprise"] + ): + config.show_built_with_reflex = False + else: + config.show_built_with_reflex = True + + if is_prod_mode() and config.show_built_with_reflex: + app._setup_sticky_badge() + + progress.advance(task) + + compile_results = [ + (page_ctx.output_path, page_ctx.output_code) + for page_ctx in compile_ctx.compiled_pages.values() + if page_ctx.output_path is not None and page_ctx.output_code is not None + ] + + # Reinitialize vite config in case runtime options have changed. + compile_results.append(( + constants.ReactRouter.VITE_CONFIG_FILE, + frontend_skeleton._compile_vite_config(config), + )) - Args: - style: The style to compile. + all_imports = compile_ctx.all_imports - Returns: - The path and code of the compiled theme. + if app._state is None and any( + code_uses_state_contexts(page_ctx.output_code or "") + for page_ctx in compile_ctx.compiled_pages.values() + ): + msg = ( + "To access rx.State in frontend components, at least one " + "subclass of rx.State must be defined in the app." + ) + raise ReflexRuntimeError(msg) + progress.advance(task) + + app_wrappers = _resolve_app_wrap_components(app, compile_ctx.app_wrap_components) + app_root = app._app_root(app_wrappers) + all_imports = utils.merge_imports(all_imports, app_root._get_all_imports()) + + memo_component_files, memo_components_imports = compile_memo_components( + dict.fromkeys(CUSTOM_COMPONENTS.values()), + ( + *tuple(EXPERIMENTAL_MEMOS.values()), + *tuple(compile_ctx.auto_memo_components.values()), + ), + ) + compile_results.extend(memo_component_files) + all_imports = utils.merge_imports(all_imports, memo_components_imports) + progress.advance(task) + + compile_results.append( + compile_document_root( + app.head_components, + html_lang=app.html_lang, + html_custom_attrs=( + {"suppressHydrationWarning": True, **app.html_custom_attrs} + if app.html_custom_attrs + else {"suppressHydrationWarning": True} + ), + ) + ) + progress.advance(task) + + assets_src = Path.cwd() / constants.Dirs.APP_ASSETS + if assets_src.is_dir() and not dry_run: + with console.timing("Copy assets"): + path_ops.update_directory_tree( + src=assets_src, + dest=Path.cwd() / prerequisites.get_web_dir() / constants.Dirs.PUBLIC, + ) + + save_tasks: list[ + tuple[ + Callable[..., list[tuple[str, str]] | tuple[str, str] | None], + tuple[Any, ...], + dict[str, Any], + ] + ] = [] + modify_files_tasks: list[tuple[str, str, Callable[[str], str]]] = [] + + def add_save_task( + task_fn: Callable[..., list[tuple[str, str]] | tuple[str, str] | None], + /, + *args: Any, + **kwargs: Any, + ) -> None: + save_tasks.append((task_fn, args, kwargs)) + + for plugin in config.plugins: + plugin.pre_compile( + add_save_task=add_save_task, + add_modify_task=lambda *args, plugin=plugin: modify_files_tasks.append(( + plugin.__class__.__module__ + plugin.__class__.__name__, + *args, + )), + radix_themes_plugin=radix_themes_plugin, + unevaluated_pages=list(app._unevaluated_pages.values()), + ) + + if save_tasks: + _set_progress_total(progress, task, base_total + len(save_tasks)) + + progress.advance(task, advance=len(config.plugins)) + + compile_results.append( + compile_root_stylesheet( + app.stylesheets, + app.reset_style, + plugins=compiler_plugins, + ) + ) + progress.advance(task) + + compile_results.append(compile_theme(app.style)) + progress.advance(task) + + for task_fn, args, kwargs in save_tasks: + result = task_fn(*args, **kwargs) + if result is None: + progress.advance(task) + continue + if isinstance(result, list): + compile_results.extend(result) + else: + compile_results.append(result) + progress.advance(task) + + compile_results.append( + compile_contexts(app._state, radix_themes_plugin.get_theme()) + ) + progress.advance(task) + + compile_results.append(compile_app_root(app_root)) + progress.advance(task) + + progress.stop() + + if dry_run: + return + + with console.timing("Install Frontend Packages"): + app._get_frontend_packages(all_imports) + + frontend_skeleton.update_react_router_config( + prerender_routes=prerender_routes, + ) + + if is_prod_mode(): + purge_web_pages_dir() + else: + keep_files = [Path(output_path) for output_path, _ in compile_results] + for page_file in Path( + prerequisites.get_web_dir() / constants.Dirs.PAGES / constants.Dirs.ROUTES + ).rglob("*"): + if page_file.is_file() and page_file not in keep_files: + page_file.unlink() + + output_mapping: dict[Path, str] = {} + for output_path, code in compile_results: + path = utils.resolve_path_of_web_dir(output_path) + if path in output_mapping: + console.warn( + f"Path {path} has two different outputs. The first one will be used." + ) + else: + output_mapping[path] = code + + for plugin in config.plugins: + for static_file_path, content in plugin.get_static_assets(): + path = utils.resolve_path_of_web_dir(static_file_path) + if path in output_mapping: + console.warn( + f"Plugin {plugin.__class__.__name__} is trying to write to {path} but it already exists. The plugin file will be ignored." + ) + else: + output_mapping[path] = ( + content.decode("utf-8") if isinstance(content, bytes) else content + ) + + for plugin_name, file_path, modify_fn in modify_files_tasks: + path = utils.resolve_path_of_web_dir(file_path) + file_content = output_mapping.get(path) + if file_content is None: + if path.exists(): + file_content = path.read_text() + else: + msg = f"Plugin {plugin_name} is trying to modify {path} but it does not exist." + raise FileNotFoundError(msg) + output_mapping[path] = modify_fn(file_content) - Raises: - ValueError: If the style is not set. - """ - if style is None: - msg = "STYLE should be set" - raise ValueError(msg) - return compile_theme(style) + with console.timing("Write to Disk"): + for output_path, code in output_mapping.items(): + utils.write_file(output_path, code) diff --git a/reflex/compiler/plugins/__init__.py b/reflex/compiler/plugins/__init__.py new file mode 100644 index 00000000000..92e34115e3e --- /dev/null +++ b/reflex/compiler/plugins/__init__.py @@ -0,0 +1,30 @@ +"""Built-in compiler plugins for single-pass page compilation.""" + +from reflex_base.plugins import ( + BaseContext, + CompileContext, + CompilerHooks, + ComponentAndChildren, + PageContext, +) + +from .builtin import ( + ApplyStylePlugin, + DefaultCollectorPlugin, + DefaultPagePlugin, + default_page_plugins, +) +from .memoize import MemoizeStatefulPlugin + +__all__ = [ + "ApplyStylePlugin", + "BaseContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "DefaultCollectorPlugin", + "DefaultPagePlugin", + "MemoizeStatefulPlugin", + "PageContext", + "default_page_plugins", +] diff --git a/reflex/compiler/plugins/builtin.py b/reflex/compiler/plugins/builtin.py new file mode 100644 index 00000000000..a4b326be4ab --- /dev/null +++ b/reflex/compiler/plugins/builtin.py @@ -0,0 +1,409 @@ +"""Built-in compiler plugins and the default plugin pipeline.""" + +from __future__ import annotations + +import dataclasses +from collections.abc import Callable, Sequence +from typing import Any + +from reflex_base.components.component import BaseComponent, Component, ComponentStyle +from reflex_base.config import get_config +from reflex_base.plugins import CompileContext, PageContext, PageDefinition, Plugin +from reflex_base.plugins.base import HookOrder +from reflex_base.utils.format import make_default_page_title +from reflex_base.utils.imports import collapse_imports, merge_imports +from reflex_base.vars import VarData +from reflex_components_core.base.fragment import Fragment + +from reflex.compiler import utils + + +@dataclasses.dataclass(frozen=True, slots=True) +class DefaultPagePlugin(Plugin): + """Evaluate an unevaluated page into a mutable page context.""" + + def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext: + """Evaluate the page function and attach legacy page metadata. + + Returns: + The evaluated page context. + """ + from reflex.compiler import compiler + + try: + component = compiler.into_component(page_fn) + component = Fragment.create(component) + + title = getattr(page, "title", None) + meta_args = { + "title": ( + title + if title is not None + else make_default_page_title(get_config().app_name, page.route) + ), + "image": getattr(page, "image", ""), + "meta": getattr(page, "meta", ()), + } + if (description := getattr(page, "description", None)) is not None: + meta_args["description"] = description + + utils.add_meta(component, **meta_args) + except Exception as err: + if hasattr(err, "add_note"): + err.add_note(f"Happened while evaluating page {page.route!r}") + raise + + return PageContext( + name=getattr(page_fn, "__name__", page.route), + route=page.route, + root_component=component, + ) + + +@dataclasses.dataclass(frozen=True, slots=True) +class ApplyStylePlugin(Plugin): + """Apply app-level styles in the descending phase of the walk.""" + + style: ComponentStyle | None = None + + @staticmethod + def _apply_style( + comp: Component, style: ComponentStyle, page_context: PageContext + ) -> Component | None: + """Apply app-level styles to a single component. + + Args: + comp: The component to style. + style: The app-level component style map. + page_context: The active page context, used to obtain a page-local + clone before rewriting ``style``. + + Returns: + A page-local clone with the merged style, or ``None`` when the + component has no type-level or app-level style to apply. + """ + if type(comp)._add_style != Component._add_style: + msg = "Do not override _add_style directly. Use add_style instead." + raise UserWarning(msg) + + new_style = comp._add_style() + component_style = comp._get_component_style(style) + if not new_style and not component_style: + return None + + style_vars = [new_style._var_data] + if component_style: + new_style.update(component_style) + style_vars.append(component_style._var_data) + new_style.update(comp.style) + style_vars.append(comp.style._var_data) + new_style._var_data = VarData.merge(*style_vars) + + owned = page_context.own(comp) + owned.style = new_style + return owned + + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: Any, + in_prop_tree: bool = False, + ) -> BaseComponent | None: + """Apply the non-recursive portion of ``_add_style_recursive``. + + Returns: + A page-local clone carrying the merged style, or ``None`` when no + style change applies to this component. + """ + if self.style is not None and isinstance(comp, Component) and not in_prop_tree: + return self._apply_style(comp, self.style, page_context) + return None + + def _compiler_bind_enter_component( + self, + page_context: PageContext, + compile_context: CompileContext, + ) -> Callable[[BaseComponent, bool], BaseComponent | None]: + """Bind a positional fast-path enter hook for style application. + + Returns: + A compiled enter hook that only takes hot-loop positional state. + """ + style = self.style + if style is None: + + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> BaseComponent | None: + return None + + return enter_component + + apply_style = self._apply_style + + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> BaseComponent | None: + if not isinstance(comp, Component) or in_prop_tree: + return None + return apply_style(comp, style, page_context) + + return enter_component + + +@dataclasses.dataclass(frozen=True, slots=True) +class DefaultCollectorPlugin(Plugin): + """Collect page artifacts in one fused enter/leave hook pair.""" + + # Run after replacing leave hooks so collected imports/custom-code reflect + # the final post-replacement component (e.g. memoize wrappers). + _compiler_leave_component_order = HookOrder.POST + _compiler_can_replace_leave_component = False + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: Any, + in_prop_tree: bool = False, + ) -> None: + """Collect imports and page artifacts for the active component node.""" + if not isinstance(comp, Component): + return + + imports = comp._get_imports() + if imports: + self._extend_imports(page_context.frontend_imports, imports) + + self._collect_component_custom_code(page_context.module_code, comp) + + if not in_prop_tree: + self._collect_component_hooks(page_context.hooks, comp) + + if ( + type(comp)._get_app_wrap_components + is not Component._get_app_wrap_components + ): + self._collect_app_wrap_components( + page_context.app_wrap_components, + comp, + ) + + if (dynamic_import := comp._get_dynamic_imports()) is not None: + page_context.dynamic_imports.add(dynamic_import) + + if (ref := comp.get_ref()) is not None: + page_context.refs[ref] = None + + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Collapse collected imports into a single legacy-shaped entry.""" + if page_ctx.frontend_imports: + collapsed_imports = collapse_imports( + merge_imports(page_ctx.frontend_imports, *page_ctx.imports) + if page_ctx.imports + else page_ctx.frontend_imports + ) + page_ctx.frontend_imports = collapsed_imports + page_ctx.imports = [collapsed_imports] + return + + page_ctx.imports = ( + [collapse_imports(merge_imports(*page_ctx.imports))] + if page_ctx.imports + else [] + ) + + def _compiler_bind_leave_component( + self, + page_context: PageContext, + compile_context: CompileContext, + ) -> Callable[[BaseComponent, tuple[BaseComponent, ...], bool], None]: + """Bind a positional fast-path leave hook for artifact collection. + + Returns: + A compiled leave hook that only takes hot-loop positional state. + """ + frontend_imports = page_context.frontend_imports + module_code = page_context.module_code + hooks = page_context.hooks + dynamic_imports = page_context.dynamic_imports + refs = page_context.refs + app_wrap_components = page_context.app_wrap_components + extend_imports = self._extend_imports + collect_component_hooks = self._collect_component_hooks + collect_component_custom_code = self._collect_component_custom_code + collect_app_wrap_components = self._collect_app_wrap_components + base_get_app_wrap_components = Component._get_app_wrap_components + seen_app_wrap_methods: set[object] = set() + + def leave_component( + comp: BaseComponent, + children: tuple[BaseComponent, ...], + in_prop_tree: bool, + ) -> None: + if not isinstance(comp, Component): + return + + imports_for_component = comp._get_imports() + if imports_for_component: + extend_imports(frontend_imports, imports_for_component) + + collect_component_custom_code(module_code, comp) + + if not in_prop_tree: + collect_component_hooks(hooks, comp) + + app_wrap_method = type(comp)._get_app_wrap_components + if ( + app_wrap_method is not base_get_app_wrap_components + and app_wrap_method not in seen_app_wrap_methods + ): + seen_app_wrap_methods.add(app_wrap_method) + collect_app_wrap_components(app_wrap_components, comp) + + dynamic_import = comp._get_dynamic_imports() + if dynamic_import is not None: + dynamic_imports.add(dynamic_import) + + ref = comp.get_ref() + if ref is not None: + refs[ref] = None + + return leave_component + + @staticmethod + def _collect_component_hooks( + page_hooks: dict[str, VarData | None], + component: Component, + ) -> None: + """Collect hooks for one structural-tree component in legacy order.""" + page_hooks.update(component._get_hooks_internal()) + if (user_hooks := component._get_hooks()) is not None: + page_hooks[user_hooks] = None + page_hooks.update(component._get_added_hooks()) + + @staticmethod + def _extend_imports( + target: dict[str, list[Any]], + source: dict[str, list[Any]], + ) -> None: + """Extend a parsed import mapping in place.""" + for lib, fields in source.items(): + target.setdefault(lib, []).extend(fields) + + @staticmethod + def _collect_component_custom_code( + module_code: dict[str, None], + component: Component, + ) -> None: + """Collect custom code contributed directly by one component. + + The compiler walker visits every structural child and every component + in prop subtrees, firing ``leave_component`` on each — so this helper + only handles the current node and does not recurse. + """ + if (custom_code := component._get_custom_code()) is not None: + module_code[custom_code] = None + + for clz in component._iter_parent_classes_with_method("add_custom_code"): + for item in clz.add_custom_code(component): + module_code[item] = None + + def _collect_app_wrap_components( + self, + page_app_wrap_components: dict[tuple[int, str], Component], + component: Component, + ) -> None: + """Collect app-wrap components for a structural-tree component.""" + direct_wrappers = component._get_app_wrap_components() + if not direct_wrappers: + return + + ignore_ids = {id(wrapper) for wrapper in page_app_wrap_components.values()} + page_app_wrap_components.update(direct_wrappers) + for wrapper in direct_wrappers.values(): + wrapper_id = id(wrapper) + if wrapper_id in ignore_ids: + continue + ignore_ids.add(wrapper_id) + self._collect_wrapper_subtree_into( + wrapper, + ignore_ids, + page_app_wrap_components, + ) + + @staticmethod + def _collect_wrapper_subtree_into( + component: Component, + ignore_ids: set[int], + components: dict[tuple[int, str], Component], + ) -> None: + """Collect nested app-wrap components into ``components``.""" + direct_wrappers = component._get_app_wrap_components() + for key, wrapper in direct_wrappers.items(): + wrapper_id = id(wrapper) + if wrapper_id in ignore_ids: + continue + ignore_ids.add(wrapper_id) + components[key] = wrapper + DefaultCollectorPlugin._collect_wrapper_subtree_into( + wrapper, + ignore_ids, + components, + ) + + for child in component.children: + if not isinstance(child, Component): + continue + child_id = id(child) + if child_id in ignore_ids: + continue + ignore_ids.add(child_id) + DefaultCollectorPlugin._collect_wrapper_subtree_into( + child, + ignore_ids, + components, + ) + + +def default_page_plugins( + *, + style: ComponentStyle | None = None, + plugins: Sequence[Plugin] = (), +) -> tuple[Plugin, ...]: + """Return the default compiler plugin ordering for page compilation.""" + from reflex.compiler.plugins.memoize import MemoizeStatefulPlugin + + chain: list[Plugin] = [*plugins, DefaultPagePlugin()] + if style is not None: + chain.append(ApplyStylePlugin(style=style)) + chain.extend((DefaultCollectorPlugin(), MemoizeStatefulPlugin())) + return tuple(chain) + + +__all__ = [ + "ApplyStylePlugin", + "DefaultCollectorPlugin", + "DefaultPagePlugin", + "default_page_plugins", +] diff --git a/reflex/compiler/plugins/memoize.py b/reflex/compiler/plugins/memoize.py new file mode 100644 index 00000000000..b596b147a79 --- /dev/null +++ b/reflex/compiler/plugins/memoize.py @@ -0,0 +1,399 @@ +"""MemoizeStatefulPlugin — auto-memoize stateful components with ``rx._x.memo``. + +This plugin replaces the legacy ``StatefulComponent`` wrapping pass. It +participates in the normal single-pass walk via ``enter_component`` and inserts +per-subtree ``{children}``-pass-through wrappers built on the experimental +memo infrastructure. The wrapped subtree remains in the tree for the normal +walker descent, so downstream plugins (e.g. ``DefaultCollectorPlugin``) still +see the original components and collect their imports/hooks as usual. + +Each unique subtree shape contributes: + +- One generated experimental memo component definition, compiled into the + shared ``$/utils/components`` module. +- ``useCallback`` hook lines for each non-lifecycle event trigger, emitted into + the generated memo body so handler hooks stay inside that rendering domain. + +No shared ``stateful_components`` file is produced. +""" + +from __future__ import annotations + +import dataclasses +from typing import Any + +from reflex_base.components.component import ( + BaseComponent, + Component, + _deterministic_hash, + _hash_str, +) +from reflex_base.components.memoize_helpers import ( + MemoizationStrategy, + fix_event_triggers_for_memo, + get_memoization_strategy, + is_snapshot_boundary, +) +from reflex_base.constants.compiler import MemoizationDisposition +from reflex_base.plugins import ComponentAndChildren, PageContext +from reflex_base.plugins.base import Plugin +from reflex_base.utils import format + +from reflex.experimental.memo import create_passthrough_component_memo + + +def _compute_memo_tag(component: Component) -> str | None: + """Compute a stable tag name for a memoizable component. + + Returns ``None`` for components that render empty (non-visual components + are never memoized). + + The class qualname is encoded directly in the tag prefix so that distinct + classes which render identically never collide on a tag. Tag collision + would silently share a single cached memo wrapper across classes and drop + the later class's class-level metadata (e.g. ``_get_app_wrap_components``, + which carries providers like ``UploadFilesProvider`` that must reach the + app root). Baking the qualname into the prefix avoids re-concatenating + the rendered JSX into the hash input on every call. + + Args: + component: The component to name. + + Returns: + The stable tag name, or ``None`` if the component renders empty. + """ + rendered_code = component.render() + if not rendered_code: + return None + code_hash = _hash_str(_deterministic_hash(rendered_code)) + return format.format_state_name( + f"{type(component).__qualname__}_{component.tag or 'Comp'}_{code_hash}" + ).capitalize() + + +def _subtree_has_reactive_data( + component: Component, _cache: dict[int, bool] | None = None +) -> bool: + """Whether ``component``'s subtree carries reactive signals worth memoizing. + + No-arg event handlers (``on_click=State.ping``) contribute hooks only via + ``event_triggers`` / ``_get_events_hooks``, not as a Var, so the per-Var + scan must be paired with an explicit ``event_triggers`` check. + + ``useRef`` from a static ``id`` prop is intentionally ignored — it lives + in ``_get_hooks_internal``, not in any Var, so static-id-only elements + don't surface here and aren't flagged. + + Args: + component: The component whose subtree to inspect. + _cache: Internal ``id()``-keyed cache of per-subtree results so + components reachable via overlapping ``var_data.components`` and + ``children`` paths are evaluated once. ``False`` is also used as + a transient placeholder while a subtree is being computed to + break cycles. + + Returns: + True if the subtree carries event triggers, explicit hooks, or any + Var whose merged var_data has ``state`` or ``hooks``. + """ + if _cache is None: + _cache = {} + key = id(component) + cached = _cache.get(key) + if cached is not None: + return cached + # Placeholder breaks cycles: a subtree that references itself is + # treated as non-reactive on the recursive arm; the real result for + # this node is written back below. + _cache[key] = False + result = _component_subtree_is_reactive(component, _cache) + _cache[key] = result + return result + + +def _component_subtree_is_reactive( + component: Component, _cache: dict[int, bool] +) -> bool: + """Inner walk for :func:`_subtree_has_reactive_data` (uncached node check). + + Internal hooks (``_get_hooks_internal``) cover event-trigger callbacks, + lifecycle hooks (``on_mount``/``on_unmount``), and Var-derived hooks + (state context, client state, custom). The static ``id`` ref hook is + explicitly subtracted so an id-only element does not flag as reactive. + + Args: + component: The component to inspect. + _cache: Shared cache passed through recursive calls. + + Returns: + True if ``component`` itself or any reachable descendant carries + reactive signals. + """ + ref_hook = component._get_ref_hook() + ref_hook_key = str(ref_hook) if ref_hook is not None else None + for hook_key in component._get_hooks_internal(): + if hook_key != ref_hook_key: + return True + if component._get_hooks() is not None or component._get_added_hooks(): + return True + for var in component._get_vars(include_children=False): + var_data = var._get_all_var_data() + if var_data is None: + continue + if var_data.state or var_data.hooks: + return True + for comp in var_data.components: + if isinstance(comp, Component) and _subtree_has_reactive_data(comp, _cache): + return True + for child in component.children: + if isinstance(child, Component) and _subtree_has_reactive_data(child, _cache): + return True + return False + + +def _should_memoize(component: Component) -> bool: + """Decide whether ``component`` is a candidate for auto-memoization. + + Snapshot boundaries (``recursive=False``) suppress their descendants, + so a stateful subtree must trigger wrapping at the boundary itself — + otherwise the state read leaks into the page module. Other components + are evaluated from their own props/triggers; descendants are visited + independently by the walker. + + Args: + component: The candidate component. + + Returns: + True if the component should be wrapped in a memo definition. + """ + from reflex_components_core.base.bare import Bare + from reflex_components_core.core.cond import Cond + from reflex_components_core.core.match import Match + + strategy = get_memoization_strategy(component) + + if component._memoization_mode.disposition == MemoizationDisposition.NEVER: + return False + if isinstance(component, Bare): + # A stateful value will be wrapped in a separate component. Match the + # per-Var predicate used by ``_subtree_has_reactive_data`` so a Bare + # whose Var carries only imports (no state/hooks) is not memoized. + contents_var_data = component.contents._get_all_var_data() + if contents_var_data is not None: + if contents_var_data.state or contents_var_data.hooks: + return True + for embedded in contents_var_data.components: + if isinstance(embedded, Component) and _subtree_has_reactive_data( + embedded + ): + return True + # Cond and Match render conditional branch JSX from their own props rather + # than from a tag, so they have no `tag` but still must be considered. + if component.tag is None and not isinstance(component, (Cond, Match)): + return False + if component._memoization_mode.disposition == MemoizationDisposition.ALWAYS: + return True + + # Direct Vars only (component's own props, style, class_name, id, etc.). + # Match the per-Var predicate used by ``_subtree_has_reactive_data`` + # var_data carrying only imports is not reactive. + for prop_var in component._get_vars(include_children=False): + var_data = prop_var._get_all_var_data() + if var_data is None: + continue + if var_data.state or var_data.hooks: + return True + for embedded in var_data.components: + if isinstance(embedded, Component) and _subtree_has_reactive_data(embedded): + return True + + if strategy is MemoizationStrategy.SNAPSHOT and not is_snapshot_boundary(component): + return True + + if is_snapshot_boundary(component) and _subtree_has_reactive_data(component): + return True + + # Components with event triggers are always memoized (to wrap callbacks). + return bool(component.event_triggers) + + +@dataclasses.dataclass(frozen=True, slots=True) +class MemoizeStatefulPlugin(Plugin): + """Auto-memoize stateful components with experimental-memo wrappers. + + Registered in ``default_page_plugins`` before ``DefaultCollectorPlugin``. + Components either render as passthrough memo wrappers or snapshot memo + wrappers (see ``get_memoization_strategy``): + + - Snapshot wrappers (``MemoizationLeaf``-style boundaries and structural + ``Foreach`` wrappers): wrapped in ``enter_component`` + and returned with empty structural children. The walker skips descent, so + hooks attached to the captured body are compiled into the memo body only. + - Passthrough wrappers are wrapped in + ``leave_component`` after descendants have already compiled, so any inner + memo wrappers flow into this wrapper's children. + + Descendants of a snapshot boundary are never independently memoized; the + boundary owns the wrapping decision for its whole subtree. This is tracked + via ``PageContext.memoize_suppressor_stack`` — a stack of component ids + that pushed suppression, popped in ``leave_component`` when the matching + component leaves. + """ + + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: Any, + in_prop_tree: bool = False, + ) -> BaseComponent | ComponentAndChildren | None: + """Memoize snapshot-boundary subtrees before descent. + + Snapshot boundaries (``MemoizationLeaf``-style, see + ``is_snapshot_boundary``) stash state-referencing hooks inside + internally-built structural children. If we waited until + ``leave_component`` to swap the boundary for its memo wrapper, the + walker would have already descended and the collector plugin would + have pulled those hooks into page scope. Returning the wrapper with + empty structural children here causes the walker to skip the descent + entirely — the boundary's full snapshot lives only in the memo + component definition compiled separately. + + Non-boundary components are handled in ``leave_component`` so their + already-compiled children flow into the wrapper. + + Args: + comp: The component being visited. + page_context: The active page context. + compile_context: The active compile context. + in_prop_tree: Whether the component is in a prop subtree. + + Returns: + A ``(wrapper, ())`` replacement for memoized boundaries, otherwise + ``None``. + """ + if in_prop_tree: + return None + if not isinstance(comp, Component): + return None + if page_context.memoize_suppressor_stack: + return None + strategy = get_memoization_strategy(comp) + if strategy is not MemoizationStrategy.SNAPSHOT: + return None + snapshot_boundary = is_snapshot_boundary(comp) + + if not _should_memoize(comp): + # Boundary not worth wrapping — still suppress descendants so + # they don't memoize independently of the boundary's subtree. + if snapshot_boundary: + page_context.memoize_suppressor_stack.append(id(comp)) + return None + + wrapper = self._build_wrapper( + comp, + page_context, + compile_context, + ) + return None if wrapper is None else (wrapper, ()) + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: Any, + in_prop_tree: bool = False, + ) -> BaseComponent | ComponentAndChildren | None: + """Wrap non-boundary memoizables and pop any suppression this component pushed. + + Args: + comp: The component being visited. + children: Its compiled children (unused; the wrapper reads from + ``comp.children`` which the walker has already updated). + page_context: The active page context. + compile_context: The active compile context. + in_prop_tree: Whether the component is in a prop subtree. + + Returns: + The memo wrapper for non-boundary memoizables, else ``None``. + """ + if in_prop_tree: + return None + if not isinstance(comp, Component): + return None + + stack = page_context.memoize_suppressor_stack + if stack and stack[-1] == id(comp): + stack.pop() + + if stack: + return None + + if len(children) != len(comp.children) or any( + compiled_child is not current_child + for compiled_child, current_child in zip( + children, comp.children, strict=True + ) + ): + comp = page_context.own(comp) + comp.children = list(children) + + strategy = get_memoization_strategy(comp) + if strategy is MemoizationStrategy.SNAPSHOT: + return None + + if not _should_memoize(comp): + return None + + return self._build_wrapper(comp, page_context, compile_context) + + @staticmethod + def _build_wrapper( + comp: Component, + page_context: PageContext, + compile_context: Any, + ) -> BaseComponent | None: + """Return the memo wrapper component for ``comp``, or ``None`` if untagged. + + Rewrites ``comp.event_triggers`` on a page-local clone via + :func:`fix_event_triggers_for_memo` so the memo body renders the + memoized ``useCallback`` forms, and registers the memo definition on + ``compile_context`` so the memo module compile pass emits it. + + Args: + comp: The component being memoized. + page_context: The active page context. + compile_context: The active compile context. + + Returns: + The wrapper instance, or ``None`` if the component's render is + empty and has no meaningful tag. + """ + tag = _compute_memo_tag(comp) + if tag is None: + return None + + comp = fix_event_triggers_for_memo(comp, page_context) + + compile_context.memoize_wrappers[tag] = None + # Passthrough memo definitions capture app-specific event/state vars, so + # they must be rebuilt for each compile instead of shared globally. + wrapper_factory, definition = create_passthrough_component_memo(tag, comp) + compile_context.auto_memo_components[tag] = definition + + wrapper = wrapper_factory() + # The wrapper has no structural children at the page level, but parents + # walking ``_get_all_refs`` (e.g. ``Form._get_form_refs`` collecting + # ref_<id> mappings into ``handleSubmit``) need to see refs from the + # wrapped subtree. Delegate ref collection to the original component + # so descendants inside the memo body remain reachable for ref lookup. + object.__setattr__(wrapper, "_get_all_refs", comp._get_all_refs) + return wrapper + + +__all__ = ["MemoizeStatefulPlugin"] diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index bb812c67c5f..c2bdf618650 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -336,11 +336,15 @@ def compile_custom_component( render = component.get_component() # Get the imports. - imports: ParsedImportDict = { - lib: fields - for lib, fields in render._get_all_imports().items() - if lib != component.library - } + imports: ParsedImportDict = {} + for lib, fields in render._get_all_imports().items(): + if lib != component.library: + imports[lib] = fields + continue + + filtered_fields = [field for field in fields if field.tag != component.tag] + if filtered_fields: + imports[lib] = filtered_fields imports.setdefault("@emotion/react", []).append(ImportVar("jsx")) @@ -373,15 +377,47 @@ def _apply_component_style_for_compile(component: Component) -> Component: Returns: The styled component tree. """ + component._add_style_recursive(_app_style()) + return component + + +def _apply_root_style(component: Component) -> None: + """Merge app-level style into ``component.style`` without recursing. + + Used for passthrough memo bodies where descendants render (and are styled) + in the page scope — only the root's style needs merging here. + + Args: + component: The root component to style in place. + """ + if type(component)._add_style != Component._add_style: + msg = "Do not override _add_style directly. Use add_style instead." + raise UserWarning(msg) + style = _app_style() + new_style = component._add_style() + style_vars = [new_style._var_data] + component_style = component._get_component_style(style) + if component_style: + new_style.update(component_style) + style_vars.append(component_style._var_data) + new_style.update(component.style) + style_vars.append(component.style._var_data) + new_style._var_data = VarData.merge(*style_vars) + component.style = new_style + + +def _app_style() -> ComponentStyle | Style: + """Return the active app-level component style map, or an empty one. + + Returns: + The app-level style map. + """ try: from reflex.utils.prerequisites import get_and_validate_app - style = get_and_validate_app().app.style + return get_and_validate_app().app.style except Exception: - style = {} - - component._add_style_recursive(style) - return component + return {} def compile_experimental_component_memo( @@ -395,12 +431,49 @@ def compile_experimental_component_memo( Returns: A tuple of the compiled component definition and its imports. """ - render = _apply_component_style_for_compile(copy.deepcopy(definition.component)) - + hole_child = definition.passthrough_hole_child + if hole_child is not None: + # Passthrough memo: shallow-copy the root only — ``render.children`` + # still aliases the user-authored descendants so root-level walkers + # (e.g. ``Form._get_form_refs``) can introspect the real subtree, but + # we skip the O(n) deepcopy + recursive style pass. Descendants are + # rendered AND styled in the page scope, not here, so only the root + # needs app-level style merged. + render = copy.copy(definition.component) + _apply_root_style(render) + + hooks = _root_only_hooks(render) + custom_code = _root_only_custom_code(render) + dynamic_imports = _root_only_dynamic_imports(render) + # Strings returned by the root's ``add_hooks`` can reference symbols + # (``refs``, ``StateContexts``, etc.) that normally reach this module + # through descendants' ``_get_hooks_imports`` / ``_get_imports``. JS + # imports are side-effect-free and dedup cleanly, so pulling the + # whole subtree's imports here is safe even when some go unused. + # ``_get_all_imports`` is read-only on the descendants, so the shallow + # aliasing above is fine. + all_imports = render._get_all_imports() + + # Swap children for JSX render: the memo body template emits a + # ``{children}`` hole in place of the real descendants. + render.children = [hole_child] + rendered = render.render() + else: + render = _apply_component_style_for_compile(copy.deepcopy(definition.component)) + rendered = render.render() + hooks = render._get_all_hooks() + custom_code = render._get_all_custom_code() + dynamic_imports = render._get_all_dynamic_imports() + all_imports = render._get_all_imports() + + # Each experimental memo now lives in ``web/utils/components/<name>.jsx``, + # so importing the ``$/utils/components`` index from this file is only + # circular when ``<name>`` itself appears in that index — i.e. a legacy + # ``@rx.memo`` wrapper file. For auto-memo wrappers around legacy custom + # components, the index import is legitimate and must be preserved. + self_module = f"$/{constants.Dirs.COMPONENTS_PATH}/{definition.export_name}" imports: ParsedImportDict = { - lib: fields - for lib, fields in render._get_all_imports().items() - if lib != f"$/{constants.Dirs.COMPONENTS_PATH}" + lib: fields for lib, fields in all_imports.items() if lib != self_module } imports.setdefault("@emotion/react", []).append(ImportVar("jsx")) @@ -424,15 +497,69 @@ def compile_experimental_component_memo( fields=tuple(signature_fields), rest=rest_param.placeholder_name if rest_param is not None else None, ).to_javascript(), - "render": render.render(), - "hooks": render._get_all_hooks(), - "custom_code": render._get_all_custom_code(), - "dynamic_imports": render._get_all_dynamic_imports(), + "render": rendered, + "hooks": hooks, + "custom_code": custom_code, + "dynamic_imports": dynamic_imports, }, imports, ) +def _root_only_hooks(component: Component) -> dict[str, VarData | None]: + """Return hooks contributed by ``component`` itself, not its subtree. + + Used by the passthrough memo compile path where descendants render in the + page scope — only the wrapper's own hooks (internal + ``add_hooks`` + + explicit ``_get_hooks``) belong in the memo body. + + Args: + component: The root component whose own hooks to collect. + + Returns: + The root-level hook map, keyed by hook source string. + """ + code: dict[str, VarData | None] = {} + code.update(component._get_hooks_internal()) + explicit = component._get_hooks() + if explicit is not None: + code[explicit] = None + code.update(component._get_added_hooks()) + return code + + +def _root_only_custom_code(component: Component) -> dict[str, None]: + """Return custom code contributed by ``component`` itself, not its subtree. + + Args: + component: The root component whose own custom code to collect. + + Returns: + The root-level custom code snippets. + """ + code: dict[str, None] = {} + own = component._get_custom_code() + if own is not None: + code[own] = None + for clz in component._iter_parent_classes_with_method("add_custom_code"): + for item in clz.add_custom_code(component): + code[item] = None + return code + + +def _root_only_dynamic_imports(component: Component) -> set[str]: + """Return dynamic imports contributed by ``component`` itself. + + Args: + component: The root component whose own dynamic imports to collect. + + Returns: + The root-level dynamic imports. + """ + own = component._get_dynamic_imports() + return {own} if own else set() + + def compile_experimental_function_memo( definition: ExperimentalMemoFunctionDefinition, ) -> tuple[dict, ParsedImportDict]: @@ -446,10 +573,13 @@ def compile_experimental_function_memo( """ imports: ParsedImportDict = {} if var_data := definition.function._get_all_var_data(): + # Per-file memo modules live at ``$/utils/components/<name>``; strip + # only a self-import to this function memo's own module. + self_module = f"$/{constants.Dirs.COMPONENTS_PATH}/{definition.python_name}" imports = { lib: list(fields) for lib, fields in dict(var_data.imports).items() - if lib != f"$/{constants.Dirs.COMPONENTS_PATH}" + if lib != self_module } return ( @@ -656,16 +786,15 @@ def get_components_path() -> str: ) -def get_stateful_components_path() -> str: - """Get the path of the compiled stateful components. +def get_memo_components_dir() -> str: + """Get the directory that holds per-memo module files. Returns: - The path of the compiled stateful components. + The directory used for per-memo ``.jsx`` modules re-exported by the + top-level components index. """ return str( - get_web_dir() - / constants.Dirs.UTILS - / (constants.PageNames.STATEFUL_COMPONENTS + constants.Ext.JSX) + get_web_dir() / constants.Dirs.UTILS / constants.PageNames.COMPONENTS, ) diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py index 95fdbf06885..e24315b4734 100644 --- a/reflex/experimental/client_state.py +++ b/reflex/experimental/client_state.py @@ -22,28 +22,36 @@ } -def _client_state_ref(var_name: str) -> str: - """Get the ref path for a ClientStateVar. +def _client_state_ref(var_name: str) -> Var: + """Get the ref accessor Var for a ClientStateVar. Args: var_name: The name of the variable. Returns: - An accessor for ClientStateVar ref as a string. + A Var that accesses the ClientStateVar ref slot, carrying the + ``refs`` import from ``$/utils/state``. """ - return f"refs['_client_state_{var_name}']" + return Var( + _js_expr=f"refs['_client_state_{var_name}']", + _var_data=VarData(imports=_refs_import), + ) -def _client_state_ref_dict(var_name: str) -> str: - """Get the ref path for a ClientStateVar. +def _client_state_ref_dict(var_name: str) -> Var: + """Get the per-instance ref-dict accessor Var for a ClientStateVar. Args: var_name: The name of the variable. Returns: - An accessor for ClientStateVar ref as a string. + A Var that accesses the ClientStateVar's per-instance ref dict, + carrying the ``refs`` import from ``$/utils/state``. """ - return f"refs['_client_state_dict_{var_name}']" + return Var( + _js_expr=f"refs['_client_state_dict_{var_name}']", + _var_data=VarData(imports=_refs_import), + ) @dataclasses.dataclass( @@ -132,6 +140,10 @@ def create( } if global_ref: arg_name = get_unique_variable_name() + setter_ref = _client_state_ref(setter_name) + var_ref = _client_state_ref(var_name) + var_dict_ref = _client_state_ref_dict(var_name) + setter_dict_ref = _client_state_ref_dict(setter_name) func = ArgsFunctionOperationBuilder.create( args_names=(arg_name,), return_expr=Var("Array.prototype.forEach.call") @@ -140,13 +152,13 @@ def create( ( Var("Object.values") .to(FunctionVar) - .call(Var(_client_state_ref_dict(setter_name))) + .call(setter_dict_ref) .to(list) .to(list) ) - + Var.create([ - Var(f"(value) => {{ {_client_state_ref(var_name)} = value; }}") - ]).to(list), + + Var.create([Var(f"(value) => {{ {var_ref} = value; }}")]).to( + list + ), ArgsFunctionOperationBuilder.create( args_names=("setter",), return_expr=Var("setter").to(FunctionVar).call(Var(arg_name)), @@ -154,17 +166,16 @@ def create( ), ) - hooks[f"{_client_state_ref(setter_name)} = {func!s}"] = None - hooks[f"{_client_state_ref(var_name)} ??= {var_name!s}"] = None - hooks[f"{_client_state_ref_dict(var_name)} ??= {{}}"] = None - hooks[f"{_client_state_ref_dict(setter_name)} ??= {{}}"] = None - hooks[ - f"{_client_state_ref_dict(var_name)}[{id_name}] = {_client_state_ref(var_name)}" - ] = None - hooks[ - f"{_client_state_ref_dict(setter_name)}[{id_name}] = {setter_name}" - ] = None - imports.update(_refs_import) + hooks[f"{setter_ref!s} = {func!s}"] = setter_ref._get_all_var_data() + hooks[f"{var_ref!s} ??= {var_name!s}"] = var_ref._get_all_var_data() + hooks[f"{var_dict_ref!s} ??= {{}}"] = var_dict_ref._get_all_var_data() + hooks[f"{setter_dict_ref!s} ??= {{}}"] = setter_dict_ref._get_all_var_data() + hooks[f"{var_dict_ref!s}[{id_name}] = {var_ref!s}"] = VarData.merge( + var_dict_ref._get_all_var_data(), var_ref._get_all_var_data() + ) + hooks[f"{setter_dict_ref!s}[{id_name}] = {setter_name}"] = ( + setter_dict_ref._get_all_var_data() + ) return cls( _js_expr="null", _setter_name=setter_name, @@ -192,20 +203,12 @@ def value(self) -> Var: Returns: an accessor for the client state variable. """ - return ( - Var( - _js_expr=( - _client_state_ref_dict(self._getter_name) + f"[{self._id_name}]" - if self._global_ref - else self._getter_name - ), - _var_data=self._var_data, - ) - .to(self._var_type) - ._replace( - merge_var_data=VarData(imports=_refs_import if self._global_ref else {}) - ) + js_expr = ( + f"{_client_state_ref_dict(self._getter_name)}[{self._id_name}]" + if self._global_ref + else self._getter_name ) + return Var(_js_expr=js_expr, _var_data=self._var_data).to(self._var_type) def set_value(self, value: Any = NoValue) -> Var: """Set the value of the client state variable. @@ -220,12 +223,10 @@ def set_value(self, value: Any = NoValue) -> Var: Returns: A special EventChain Var which will set the value when triggered. """ - var_data = VarData(imports=_refs_import if self._global_ref else {}) - setter = ( - Var(_client_state_ref(self._setter_name)) + _client_state_ref(self._setter_name) if self._global_ref - else Var(self._setter_name, _var_data=var_data) + else Var(self._setter_name) ).to(FunctionVar) if value is not NoValue: diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index 7dee0c72eea..0dbf91bfbf4 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -5,13 +5,22 @@ import dataclasses import inspect from collections.abc import Callable +from copy import copy from functools import cache, update_wrapper from typing import Any, get_args, get_origin, get_type_hints from reflex_base import constants from reflex_base.components.component import Component from reflex_base.components.dynamic import bundled_libraries -from reflex_base.constants.compiler import SpecialAttributes +from reflex_base.components.memoize_helpers import ( + MemoizationStrategy, + get_memoization_strategy, +) +from reflex_base.constants.compiler import ( + MemoizationDisposition, + MemoizationMode, + SpecialAttributes, +) from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER from reflex_base.utils import format from reflex_base.utils.imports import ImportVar @@ -69,12 +78,33 @@ class ExperimentalMemoComponentDefinition(ExperimentalMemoDefinition): export_name: str component: Component + # For passthrough wrappers built by the auto-memoize plugin: the + # ``Bare``-wrapped ``{children}`` placeholder used when rendering the memo + # body. The ``component`` keeps its ORIGINAL children so compile-time + # walkers (``Form._get_form_refs`` etc.) can introspect the subtree; the + # compiler swaps to this placeholder only for the JSX render and for + # imports collection, so descendants emit their refs/imports/hooks in the + # page scope rather than being duplicated inside the memo body. + passthrough_hole_child: Component | None = None class ExperimentalMemoComponent(Component): """A rendered instance of an experimental memo component.""" library = f"$/{constants.Dirs.COMPONENTS_PATH}" + _memoization_mode = MemoizationMode(disposition=MemoizationDisposition.NEVER) + + def _validate_component_children(self, children: list[Component]) -> None: + """Skip direct parent/child validation for memo wrapper instances. + + Experimental memos wrap an underlying compiled component definition. + The runtime wrapper should not interpose on `_valid_parents` checks for + the authored subtree because the wrapper itself is not the semantic + parent in the user-authored component tree. + + Args: + children: The children of the component (ignored). + """ def _post_init(self, **kwargs): """Initialize the experimental memo component. @@ -119,22 +149,45 @@ def _post_init(self, **kwargs): @cache def _get_experimental_memo_component_class( export_name: str, + wrapped_component_type: type[Component] = Component, ) -> type[ExperimentalMemoComponent]: """Get the component subclass for an experimental memo export. + Class-level metadata that the compiler reads via ``type(comp)._get_*()`` + (notably ``_get_app_wrap_components``, which carries providers like + ``UploadFilesProvider`` that must reach the app root) is inherited from + ``wrapped_component_type`` so the wrapper is a transparent substitute for + the original in the compile tree. + Args: export_name: The exported React component name. + wrapped_component_type: The class of the component being memoized. + Defaults to ``Component`` for memos that don't wrap a user + component (e.g. function memos, raw passthroughs). Returns: A cached component subclass with the tag set at class definition time. """ + attrs: dict[str, Any] = { + "__module__": __name__, + "tag": export_name, + # Point each memo at its own per-file module so pages import directly + # from ``$/utils/components/<name>`` rather than through the index. + # Per-file import paths give Vite distinct module boundaries per + # memo, enabling actual code-split by page. + "library": f"$/{constants.Dirs.COMPONENTS_PATH}/{export_name}", + } + if ( + wrapped_component_type._get_app_wrap_components + is not Component._get_app_wrap_components + ): + attrs["_get_app_wrap_components"] = staticmethod( + wrapped_component_type._get_app_wrap_components + ) return type( f"ExperimentalMemoComponent_{export_name}", (ExperimentalMemoComponent,), - { - "__module__": __name__, - "tag": export_name, - }, + attrs, ) @@ -300,7 +353,9 @@ def _imported_function_var(name: str, return_type: Any) -> FunctionVar: name, _var_type=ReflexCallable[Any, return_type], _var_data=VarData( - imports={f"$/{constants.Dirs.COMPONENTS_PATH}": [ImportVar(tag=name)]} + imports={ + f"$/{constants.Dirs.COMPONENTS_PATH}/{name}": [ImportVar(tag=name)] + } ), ) @@ -319,7 +374,7 @@ def _component_import_var(name: str) -> Var: _var_type=type[Component], _var_data=VarData( imports={ - f"$/{constants.Dirs.COMPONENTS_PATH}": [ImportVar(tag=name)], + f"$/{constants.Dirs.COMPONENTS_PATH}/{name}": [ImportVar(tag=name)], "@emotion/react": [ImportVar(tag="jsx")], } ), @@ -906,7 +961,9 @@ def __call__(self, *children: Any, **props: Any) -> ExperimentalMemoComponent: raise TypeError(msg) # Build the component props passed into the memo wrapper. - return _get_experimental_memo_component_class(definition.export_name)._create( + return _get_experimental_memo_component_class( + definition.export_name, type(definition.component) + )._create( children=list(children), memo_definition=definition, **explicit_values, @@ -950,6 +1007,67 @@ def _create_component_wrapper( return _ExperimentalMemoComponentWrapper(definition) +def create_passthrough_component_memo( + export_name: str, + component: Component, +) -> tuple[ + Callable[..., ExperimentalMemoComponent], + ExperimentalMemoComponentDefinition, +]: + """Create an unregistered ``@rx._x.memo``-style passthrough component memo. + + This is used by compiler auto-memoization so generated wrappers compile + through the experimental memo pipeline instead of emitting ad-hoc page-local + ``React.memo`` declarations. + + Args: + export_name: The exported memo component name. + component: The component to wrap. + + Returns: + The callable memo wrapper and its component definition. + """ + # Snapshot-boundary components (see ``is_snapshot_boundary``) own their + # subtree — the ``.children`` slot is internal machinery from the + # subclass's ``.create`` (e.g. the dropzone Div built inside + # ``Upload.create``), not a user content hole. The memoize plugin wraps + # the boundary with no structural children on the page side, so the memo + # body renders the full snapshot rather than a ``{children}``-holed + # template. + render_snapshot = ( + get_memoization_strategy(component) is MemoizationStrategy.SNAPSHOT + ) + + captured_hole_child: list[Component] = [] + + def passthrough(children: Var[Component]) -> Component: + new_component = copy(component) + if render_snapshot: + return new_component + # Keep ``new_component.children`` as the ORIGINAL children so + # compile-time walkers that introspect the subtree (e.g. Form's + # ``_get_form_refs``) see the real descendants. The ``{children}`` + # hole lives on the definition and the compiler swaps it in only for + # JSX render / imports collection. + captured_hole_child.append(Bare.create(children)) + return new_component + + passthrough.__name__ = format.to_snake_case(export_name) + passthrough.__qualname__ = passthrough.__name__ + passthrough.__module__ = __name__ + + definition = _create_component_definition(passthrough, Component) + replacements: dict[str, Any] = {} + if definition.export_name != export_name: + replacements["export_name"] = export_name + if captured_hole_child: + replacements["passthrough_hole_child"] = captured_hole_child[0] + if replacements: + definition = dataclasses.replace(definition, **replacements) + + return _create_component_wrapper(definition), definition + + def memo(fn: Callable[..., Any]) -> Callable[..., Any]: """Create an experimental memo from a function. @@ -986,3 +1104,14 @@ def memo(fn: Callable[..., Any]) -> Callable[..., Any]: f"got `{return_annotation}`." ) raise TypeError(msg) + + +__all__ = [ + "EXPERIMENTAL_MEMOS", + "ExperimentalMemoComponent", + "ExperimentalMemoComponentDefinition", + "ExperimentalMemoDefinition", + "ExperimentalMemoFunctionDefinition", + "create_passthrough_component_memo", + "memo", +] diff --git a/reflex/plugins/__init__.py b/reflex/plugins/__init__.py index ac50297d040..9bd4335ab02 100644 --- a/reflex/plugins/__init__.py +++ b/reflex/plugins/__init__.py @@ -1,7 +1,12 @@ """Re-export from reflex_base.plugins.""" from reflex_base.plugins import ( + BaseContext, CommonContext, + CompileContext, + CompilerHooks, + ComponentAndChildren, + PageContext, Plugin, PreCompileContext, SitemapPlugin, @@ -12,11 +17,18 @@ tailwind_v3, tailwind_v4, ) +from reflex_components_radix.plugin import RadixThemesPlugin __all__ = [ + "BaseContext", "CommonContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "PageContext", "Plugin", "PreCompileContext", + "RadixThemesPlugin", "SitemapPlugin", "TailwindV3Plugin", "TailwindV4Plugin", diff --git a/tests/benchmarks/fixtures.py b/tests/benchmarks/fixtures.py index c20ac177660..63469330109 100644 --- a/tests/benchmarks/fixtures.py +++ b/tests/benchmarks/fixtures.py @@ -1,10 +1,14 @@ +from collections.abc import Callable from dataclasses import dataclass -from typing import cast +from typing import Any, cast import pytest from pydantic import BaseModel +from reflex_base.components.component import BaseComponent, Component +from reflex_base.plugins import CompileContext, PageContext import reflex as rx +from reflex.compiler.plugins import DefaultCollectorPlugin class SideBarState(rx.State): @@ -221,6 +225,53 @@ class NestedElement(BaseModel): value: list[int] +@dataclass(frozen=True, slots=True) +class ImportOnlyCollectorPlugin(DefaultCollectorPlugin): + """Collect only imports — same scope as Component._get_all_imports. + + Inherits import collection from DefaultCollectorPlugin but disables + hooks, custom code, app_wrap, and stateful code rendering. + """ + + _compiler_stateful_only_leave_component = False + + def leave_component(self, *_args: Any, **_kwargs: Any) -> None: + """No-op: skip stateful code rendering.""" + + def _compiler_bind_leave_component( + self, *_args: Any, **_kwargs: Any + ) -> Callable[..., None]: + """Return a no-op leave hook.""" + + def _noop(*_a: Any, **_kw: Any) -> None: + pass + + return _noop + + def _compiler_bind_enter_component( + self, + page_context: PageContext, + compile_context: CompileContext, + ) -> Callable[[BaseComponent, bool], None]: + del compile_context + + frontend_imports = page_context.frontend_imports + extend_imports = self._extend_imports + + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> None: + if not isinstance(comp, Component) or in_prop_tree: + return + + imports = comp._get_imports() + if imports: + extend_imports(frontend_imports, imports) + + return enter_component + + class BenchmarkState(rx.State): """State for the benchmark.""" diff --git a/tests/benchmarks/test_compilation.py b/tests/benchmarks/test_compilation.py index 7ad0f666f76..f9b1f134e5b 100644 --- a/tests/benchmarks/test_compilation.py +++ b/tests/benchmarks/test_compilation.py @@ -1,7 +1,16 @@ +import copy + from pytest_codspeed import BenchmarkFixture from reflex_base.components.component import Component +from reflex_base.plugins import CompileContext, CompilerHooks, PageContext + +from reflex.app import UnevaluatedPage +from reflex.compiler import compiler +from reflex.compiler.compiler import _compile_page +from reflex.compiler.plugins import DefaultCollectorPlugin, default_page_plugins +from reflex.compiler.plugins.memoize import MemoizeStatefulPlugin -from reflex.compiler.compiler import _compile_page, _compile_stateful_components +from .fixtures import ImportOnlyCollectorPlugin def import_templates(): @@ -9,17 +18,126 @@ def import_templates(): import reflex.compiler.templates # noqa: F401 +def _compile_single_pass_page_ctx(component: Component) -> PageContext: + # The single-pass compiler mutates the tree in place when it inserts memo + # wrappers, so benchmark iterations need an isolated copy of the input. + component = copy.deepcopy(component) + page_ctx = PageContext( + name="benchmark", + route="/benchmark", + root_component=component, + ) + hooks = CompilerHooks( + plugins=(MemoizeStatefulPlugin(), DefaultCollectorPlugin()), + ) + compile_ctx = CompileContext(pages=[], hooks=hooks) + + with compile_ctx, page_ctx: + page_ctx.root_component = hooks.compile_component( + page_ctx.root_component, + page_context=page_ctx, + compile_context=compile_ctx, + ) + hooks.compile_page(page_ctx, compile_context=compile_ctx) + + return page_ctx + + +def _get_imports_single_pass(component: Component) -> dict: + """Collect only imports via a single-pass walk — comparable to _get_all_imports. + + Returns: + The collapsed import dict for the page. + """ + page_ctx = PageContext( + name="benchmark", + route="/benchmark", + root_component=component, + ) + hooks = CompilerHooks(plugins=(ImportOnlyCollectorPlugin(),)) + compile_ctx = CompileContext(pages=[], hooks=hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + component, + page_context=page_ctx, + compile_context=compile_ctx, + ) + hooks.compile_page(page_ctx, compile_context=compile_ctx) + + return page_ctx.frontend_imports + + +def _compile_page_single_pass(component: Component) -> str: + page_ctx = _compile_single_pass_page_ctx(component) + page_ctx.frontend_imports = page_ctx.merged_imports(collapse=True) + return compiler.compile_page_from_context(page_ctx)[1] + + +def _compile_page_full_context(unevaluated_page) -> str: + page = UnevaluatedPage(route="/benchmark", component=unevaluated_page) + compile_ctx = CompileContext( + pages=[page], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + + with compile_ctx: + compiled_pages = compile_ctx.compile() + + output_code = compiled_pages["/benchmark"].output_code + if output_code is None: + msg = "CompileContext did not produce output code for the benchmark page." + raise RuntimeError(msg) + return output_code + + def test_compile_page(evaluated_page: Component, benchmark: BenchmarkFixture): import_templates() benchmark(lambda: _compile_page(evaluated_page)) -def test_compile_stateful(evaluated_page: Component, benchmark: BenchmarkFixture): +def test_compile_page_single_pass( + evaluated_page: Component, + benchmark: BenchmarkFixture, +): import_templates() - benchmark(lambda: _compile_stateful_components([evaluated_page])) + benchmark(lambda: _compile_page_single_pass(evaluated_page)) + + +def test_compile_page_full_context( + unevaluated_page, + benchmark: BenchmarkFixture, +): + import_templates() + + benchmark(lambda: _compile_page_full_context(unevaluated_page)) def test_get_all_imports(evaluated_page: Component, benchmark: BenchmarkFixture): benchmark(lambda: evaluated_page._get_all_imports()) + + +def test_get_all_imports_single_pass( + evaluated_page: Component, + benchmark: BenchmarkFixture, +): + benchmark(lambda: _get_imports_single_pass(evaluated_page)) + + +def test_compile_single_pass_all_artifacts( + evaluated_page: Component, + benchmark: BenchmarkFixture, +): + """Full single-pass collecting all artifacts (imports, hooks, code, app_wrap). + + This is the fair comparison for the total work the old multi-pass approach + did across _get_all_imports + _get_all_hooks + _get_all_custom_code + + _get_all_app_wrap_components. + """ + benchmark( + lambda: _compile_single_pass_page_ctx(evaluated_page).merged_imports( + collapse=True + ) + ) diff --git a/tests/benchmarks/test_evaluate.py b/tests/benchmarks/test_evaluate.py index 7af08a3592e..b533b34c415 100644 --- a/tests/benchmarks/test_evaluate.py +++ b/tests/benchmarks/test_evaluate.py @@ -2,9 +2,22 @@ from pytest_codspeed import BenchmarkFixture from reflex_base.components.component import Component +from reflex_base.plugins import CompilerHooks + +from reflex.app import UnevaluatedPage +from reflex.compiler.plugins import DefaultPagePlugin def test_evaluate_page( unevaluated_page: Callable[[], Component], benchmark: BenchmarkFixture ): benchmark(unevaluated_page) + + +def test_evaluate_page_single_pass( + unevaluated_page: Callable[[], Component], + benchmark: BenchmarkFixture, +): + hooks = CompilerHooks(plugins=(DefaultPagePlugin(),)) + page = UnevaluatedPage(route="/benchmark", component=unevaluated_page) + benchmark(lambda: hooks.eval_page(page.component, page=page)) diff --git a/tests/integration/test_auto_memo.py b/tests/integration/test_auto_memo.py new file mode 100644 index 00000000000..121184f493e --- /dev/null +++ b/tests/integration/test_auto_memo.py @@ -0,0 +1,73 @@ +"""Integration tests for compiler-generated experimental memos.""" + +from collections.abc import Generator + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import AppHarness + +from .utils import poll_for_navigation + + +def AutoMemoAcrossPagesApp(): + """Reflex app that shares one stateful subtree across two pages.""" + import reflex as rx + + def shared_counter() -> rx.Component: + return rx.text(rx.State.router.page.raw_path, id="shared-value") + + def index() -> rx.Component: + return rx.vstack( + shared_counter(), + rx.link("Other", href="/other", id="to-other"), + ) + + def other() -> rx.Component: + return rx.vstack( + shared_counter(), + rx.link("Home", href="/", id="to-home"), + ) + + app = rx.App() + app.add_page(index) + app.add_page(other, route="/other") + + +@pytest.fixture +def auto_memo_app(tmp_path) -> Generator[AppHarness, None, None]: + """Start AutoMemoAcrossPagesApp app at tmp_path via AppHarness. + + Yields: + A running AppHarness instance. + """ + with AppHarness.create( + root=tmp_path, + app_source=AutoMemoAcrossPagesApp, + ) as harness: + yield harness + + +def test_auto_memo_shared_across_pages(auto_memo_app: AppHarness): + """Shared stateful subtrees compile once and render correctly on both pages.""" + assert auto_memo_app.app_instance is not None, "app is not running" + + web_sources = "\n".join( + path.read_text() for path in (auto_memo_app.app_path / ".web").rglob("*.jsx") + ) + assert "$/utils/components" in web_sources + assert "$/utils/stateful_components" not in web_sources + + driver = auto_memo_app.frontend() + shared_value = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "shared-value") + ) + assert auto_memo_app.poll_for_content(shared_value, exp_not_equal="") == "/" + + with poll_for_navigation(driver): + driver.find_element(By.ID, "to-other").click() + + shared_value = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "shared-value") + ) + assert "other" in auto_memo_app.poll_for_content(shared_value, exp_not_equal="") diff --git a/tests/integration/test_dynamic_components.py b/tests/integration/test_dynamic_components.py index cbd60fc227a..3de8eef034f 100644 --- a/tests/integration/test_dynamic_components.py +++ b/tests/integration/test_dynamic_components.py @@ -17,6 +17,7 @@ def DynamicComponents(): class DynamicComponentsState(rx.State): value: int = 10 + count: int = 0 button: rx.Component = rx.button( "Click me", @@ -34,6 +35,15 @@ def got_clicked(self): }, ) + @rx.event + def set_count(self, count: int): + """Set the counter value. + + Args: + count: The new counter value. + """ + self.count = count + @rx.var def client_token_component(self) -> rx.Component: return rx.vstack( @@ -53,6 +63,27 @@ def client_token_component(self) -> rx.Component: ), ) + @rx.var + def counter_component(self) -> rx.Component: + """Get a dynamic counter component with event handlers. + + Returns: + The dynamic counter component. + """ + return rx.hstack( + rx.button( + "-", + id="decrement", + on_click=DynamicComponentsState.set_count(self.count - 1), + ), + rx.text(self.count, id="count"), + rx.button( + "+", + id="increment", + on_click=DynamicComponentsState.set_count(self.count + 1), + ), + ) + app = rx.App() def factorial(n: int) -> int: @@ -65,6 +96,7 @@ def index(): return rx.vstack( DynamicComponentsState.client_token_component, DynamicComponentsState.button, + DynamicComponentsState.counter_component, rx.text( DynamicComponentsState._evaluate( lambda state: factorial(state.value), of_type=int @@ -135,12 +167,28 @@ def test_dynamic_components(driver, dynamic_components: AppHarness): assert update_button update_button.click() - assert ( - dynamic_components.poll_for_content(button, exp_not_equal="Click me") - == "Clicked" + assert AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "button").text == "Clicked" ) factorial = AppHarness.poll_for_or_raise_timeout( lambda: driver.find_element(By.ID, "factorial") ) assert factorial.text == "3628800" + + count = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "count") + ) + assert count.text == "0" + + increment = driver.find_element(By.ID, "increment") + increment.click() + assert AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "count").text == "1" + ) + + decrement = driver.find_element(By.ID, "decrement") + decrement.click() + assert AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "count").text == "0" + ) diff --git a/tests/integration/test_var_operations.py b/tests/integration/test_var_operations.py index 1f752b71f2a..409a0838b2e 100644 --- a/tests/integration/test_var_operations.py +++ b/tests/integration/test_var_operations.py @@ -31,6 +31,7 @@ class VarOperationState(rx.State): int_var1: rx.Field[int] = rx.field(10) int_var2: rx.Field[int] = rx.field(5) int_var3: rx.Field[int] = rx.field(7) + match_selector: rx.Field[int] = rx.field(2) float_var1: rx.Field[float] = rx.field(10.5) float_var2: rx.Field[float] = rx.field(5.5) long_float: rx.Field[float] = rx.field(13212312312.1231231) @@ -660,6 +661,17 @@ def index(): ), id="foreach_in_match", ), + # stateful component branches in a match + rx.box( + rx.match( + VarOperationState.match_selector, + (0, rx.text(VarOperationState.int_var1 + 1)), + (1, rx.text(VarOperationState.int_var2 + 2)), + (2, rx.text(VarOperationState.str_var1.upper())), + rx.text(VarOperationState.list3.length()), + ), + id="stateful_match_three_cases", + ), # Literal range var in a foreach rx.box(rx.foreach(range(42, 80, 27), rx.text.span), id="range_in_foreach1"), rx.box(rx.foreach(range(42, 80, 3), rx.text.span), id="range_in_foreach2"), @@ -993,6 +1005,8 @@ def test_var_operations(driver, var_operations: AppHarness): ("obj_length", "3"), # foreach in a match ("foreach_in_match", "first\nsecond\nthird"), + # stateful branch components in a match + ("stateful_match_three_cases", "FIRST"), # literal range in a foreach ("range_in_foreach1", "4269"), ("range_in_foreach2", "42454851545760636669727578"), diff --git a/tests/integration/tests_playwright/test_cond_match.py b/tests/integration/tests_playwright/test_cond_match.py new file mode 100644 index 00000000000..3791d792213 --- /dev/null +++ b/tests/integration/tests_playwright/test_cond_match.py @@ -0,0 +1,114 @@ +"""Integration tests for stateful ``rx.cond`` and ``rx.match`` rendering.""" + +from collections.abc import Generator + +import pytest +from playwright.sync_api import Page, expect + +from reflex.testing import AppHarness + + +def CondMatchApp(): + """App exercising conditional rendering across state transitions.""" + import reflex as rx + + class CondMatchState(rx.State): + val_a: str = "A" + val_b: str = "B" + + @rx.event + def select_a(self): + self.val_a = "A" + + @rx.event + def select_b(self): + self.val_a = "B" + + @rx.event + def select_c(self): + self.val_a = "C" + + def index(): + return rx.box( + rx.hstack( + rx.button("A", on_click=CondMatchState.select_a, id="select-a"), + rx.button("B", on_click=CondMatchState.select_b, id="select-b"), + rx.button("C", on_click=CondMatchState.select_c, id="select-c"), + ), + rx.text(CondMatchState.val_a, id="current-value"), + rx.box( + rx.cond( + CondMatchState.val_a == "A", + rx.text(CondMatchState.val_a, id="cond-true"), + rx.text(CondMatchState.val_b, id="cond-false"), + ), + id="cond-container", + ), + rx.box( + rx.match( + CondMatchState.val_a, + ("A", rx.text(CondMatchState.val_a + " is selected", id="match-a")), + ("B", rx.text(CondMatchState.val_b + " is selected", id="match-b")), + rx.text("No value selected", id="match-default"), + ), + id="match-container", + ), + ) + + app = rx.App() + app.add_page(index) + + +@pytest.fixture(scope="module") +def cond_match_app( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[AppHarness, None, None]: + """Create a harness for the cond/match regression app. + + Args: + tmp_path_factory: Pytest fixture for creating temporary directories. + + Yields: + Running AppHarness for the test app. + """ + with AppHarness.create( + root=tmp_path_factory.mktemp("cond_match_app"), + app_source=CondMatchApp, + ) as harness: + yield harness + + +def test_cond_and_match_render_only_selected_branch( + cond_match_app: AppHarness, page: Page +): + """Cond and Match should render exactly one active branch per state value. + + Args: + cond_match_app: Running harness for the cond/match app. + page: Playwright page. + """ + assert cond_match_app.frontend_url is not None + page.goto(cond_match_app.frontend_url) + + expect(page.locator("#current-value")).to_have_text("A") + expect(page.locator("#cond-true")).to_have_text("A") + expect(page.locator("#cond-false")).to_have_count(0) + expect(page.locator("#match-a")).to_have_text("A is selected") + expect(page.locator("#match-b")).to_have_count(0) + expect(page.locator("#match-default")).to_have_count(0) + + page.click("#select-b") + expect(page.locator("#current-value")).to_have_text("B") + expect(page.locator("#cond-true")).to_have_count(0) + expect(page.locator("#cond-false")).to_have_text("B") + expect(page.locator("#match-a")).to_have_count(0) + expect(page.locator("#match-b")).to_have_text("B is selected") + expect(page.locator("#match-default")).to_have_count(0) + + page.click("#select-c") + expect(page.locator("#current-value")).to_have_text("C") + expect(page.locator("#cond-true")).to_have_count(0) + expect(page.locator("#cond-false")).to_have_text("B") + expect(page.locator("#match-a")).to_have_count(0) + expect(page.locator("#match-b")).to_have_count(0) + expect(page.locator("#match-default")).to_have_text("No value selected") diff --git a/tests/integration/tests_playwright/test_memoize_edge_cases.py b/tests/integration/tests_playwright/test_memoize_edge_cases.py new file mode 100644 index 00000000000..0d4e92e5f9c --- /dev/null +++ b/tests/integration/tests_playwright/test_memoize_edge_cases.py @@ -0,0 +1,209 @@ +"""Integration tests for auto-memoization edge cases. + +These exercise components whose memoization needs special care: + +- Snapshot boundaries (``recursive=False``) such as ``AccordionTrigger`` whose + state-dependent logic lives in a descendant. Without the snapshot wrapper + the cond's state read leaks into the page module and the trigger fails to + update on state transitions. +- HTML elements with constrained content models (``<title>``, ``<meta>``, + ``<style>``, ``<script>``). Independent memoization of a stateful ``Bare`` + child renders ``jsx("title", {}, jsx(Bare_xxx, {}))`` — React stringifies + the component child as ``[object Object]`` (or refuses to render at all + for void elements). Snapshot-wrapping keeps the Bare a text interpolation + inside the parent's body. + +Test design notes: +- The page title is supplied via ``app.add_page(..., title=MemoState.title_marker)`` + so the dynamic value flows through the standard React Router metadata path + and shows up in ``document.title``. +- Style content is matched on a unique marker substring rather than common + selectors like ``body`` (which conflicts with Emotion/Sonner stylesheets). +- ``<textarea>``'s runtime value semantics belong to React (children are + initial-value-only); the no-Bare-component-child invariant is verified by + the unit tests instead. +""" + +from collections.abc import Generator + +import pytest +from playwright.sync_api import Page, expect + +from reflex.testing import AppHarness + + +def MemoEdgeCasesApp(): + """App exercising memoization edge cases.""" + import reflex as rx + + class MemoState(rx.State): + is_open: bool = False + title_marker: str = "memo-title-home" + css_marker: str = "memo-css-light" + counter: int = 0 + + @rx.event + def toggle_open(self): + self.is_open = not self.is_open + + @rx.event + def set_title_about(self): + self.title_marker = "memo-title-about" + + @rx.event + def set_css_dark(self): + self.css_marker = "memo-css-dark" + + @rx.event + def bump(self): + self.counter = self.counter + 1 + + def index(): + return rx.box( + rx.el.style("body { --memo-marker: " + MemoState.css_marker + "; }"), + rx.box( + rx.button("toggle", on_click=MemoState.toggle_open, id="toggle"), + rx.button("title", on_click=MemoState.set_title_about, id="set-title"), + rx.button("css", on_click=MemoState.set_css_dark, id="set-css"), + rx.button("bump", on_click=MemoState.bump, id="bump"), + ), + rx.accordion.root( + rx.accordion.item( + header=rx.accordion.header( + rx.accordion.trigger( + rx.cond( + MemoState.is_open, + rx.text("Hide", id="trigger-hide"), + rx.text("Show", id="trigger-show"), + ), + id="accordion-trigger", + ), + ), + content=rx.accordion.content(rx.text("body")), + value="item-1", + ), + ), + rx.text(MemoState.counter, id="counter"), + ) + + app = rx.App() + app.add_page(index, title=MemoState.title_marker) + + +@pytest.fixture(scope="module") +def memo_app( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[AppHarness, None, None]: + """Run the memoization edge-cases app under an AppHarness. + + Args: + tmp_path_factory: Pytest fixture for creating temporary directories. + + Yields: + The running harness. + """ + with AppHarness.create( + root=tmp_path_factory.mktemp("memo_edge_cases"), + app_source=MemoEdgeCasesApp, + ) as harness: + yield harness + + +def test_accordion_trigger_with_stateful_cond_updates( + memo_app: AppHarness, page: Page +) -> None: + """AccordionTrigger holding a stateful cond updates on state changes. + + Args: + memo_app: Running app harness. + page: Playwright page. + """ + assert memo_app.frontend_url is not None + page.goto(memo_app.frontend_url) + + expect(page.locator("#trigger-show")).to_have_text("Show") + expect(page.locator("#trigger-hide")).to_have_count(0) + + page.click("#toggle") + expect(page.locator("#trigger-hide")).to_have_text("Hide") + expect(page.locator("#trigger-show")).to_have_count(0) + + # Bumping an unrelated counter must not desync the trigger render. + page.click("#bump") + expect(page.locator("#counter")).to_have_text("1") + expect(page.locator("#trigger-hide")).to_have_text("Hide") + + page.click("#toggle") + expect(page.locator("#trigger-show")).to_have_text("Show") + + +def _document_contains_style(page: Page, marker: str) -> bool: + """Whether any ``<style>`` element's text content contains ``marker``. + + ``<style>`` content is not "visible" text, so the Locator ``has_text`` + filter skips it. Inspect text content via JS instead. + + Args: + page: Playwright page. + marker: Substring to look for in style element text content. + + Returns: + True if any ``<style>`` element's textContent contains the marker. + """ + return page.evaluate( + """(marker) => { + const els = document.querySelectorAll('style'); + return Array.from(els).some(el => (el.textContent || '').includes(marker)); + }""", + marker, + ) + + +def test_page_title_updates_with_state(memo_app: AppHarness, page: Page) -> None: + """The page title (passed to ``add_page(title=...)``) tracks state. + + Verifying via ``document.title`` proves the state value flows through the + standard page-metadata path and lands as the title's text node, not as a + stringified JSX component child. + + Args: + memo_app: Running app harness. + page: Playwright page. + """ + assert memo_app.frontend_url is not None + page.goto(memo_app.frontend_url) + page.wait_for_selector("#trigger-show") + + expect(page).to_have_title("memo-title-home") + + page.click("#set-title") + expect(page).to_have_title("memo-title-about") + + +def test_style_element_renders_stateful_css_as_text( + memo_app: AppHarness, page: Page +) -> None: + """``rx.el.style(state_var)`` writes the state value as the stylesheet text. + + Uses a unique marker substring so the test does not collide with Emotion + or Sonner stylesheets that also live in the document. + + Args: + memo_app: Running app harness. + page: Playwright page. + """ + assert memo_app.frontend_url is not None + page.goto(memo_app.frontend_url) + page.wait_for_selector("#trigger-show") + + assert _document_contains_style(page, "memo-css-light") + assert not _document_contains_style(page, "memo-css-dark") + + page.click("#set-css") + page.wait_for_function( + """() => Array.from(document.querySelectorAll('style')) + .some(el => (el.textContent || '').includes('memo-css-dark'))""", + timeout=5000, + ) + assert _document_contains_style(page, "memo-css-dark") + assert not _document_contains_style(page, "memo-css-light") diff --git a/tests/units/compiler/test_compiler.py b/tests/units/compiler/test_compiler.py index 47a82c08400..e2f1b769668 100644 --- a/tests/units/compiler/test_compiler.py +++ b/tests/units/compiler/test_compiler.py @@ -5,6 +5,7 @@ import pytest from pytest_mock import MockerFixture from reflex_base import constants +from reflex_base.components.dynamic import bundle_library, reset_bundled_libraries from reflex_base.constants.compiler import PageNames from reflex_base.utils.imports import ImportVar, ParsedImportDict from reflex_base.vars.base import Var @@ -14,6 +15,7 @@ from reflex_components_core.el.elements.metadata import Head, Link, Meta from reflex_components_core.el.elements.other import Html +import reflex as rx from reflex.compiler import compiler, utils @@ -164,7 +166,6 @@ def test_compile_stylesheets(tmp_path: Path, mocker: MockerFixture): ( "@layer __reflex_base;\n" "@import url('./__reflex_style_reset.css');\n" - "@import url('@radix-ui/themes/styles.css');\n" "@import url('https://fonts.googleapis.com/css?family=Sofia&effect=neon|outline|emboss|shadow-multiple');\n" "@import url('https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css');\n" "@import url('https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap-theme.min.css');\n" @@ -228,7 +229,6 @@ def test_compile_stylesheets_scss_sass(tmp_path: Path, mocker: MockerFixture): ( "@layer __reflex_base;\n" "@import url('./__reflex_style_reset.css');\n" - "@import url('@radix-ui/themes/styles.css');\n" "@import url('./style.css');\n" f"@import url('./{Path('preprocess') / Path('styles_a.css')!s}');\n" f"@import url('./{Path('preprocess') / Path('styles_b.css')!s}');" @@ -250,7 +250,6 @@ def test_compile_stylesheets_scss_sass(tmp_path: Path, mocker: MockerFixture): ( "@layer __reflex_base;\n" "@import url('./__reflex_style_reset.css');\n" - "@import url('@radix-ui/themes/styles.css');\n" "@import url('./style.css');\n" f"@import url('./{Path('preprocess') / Path('styles_a.css')!s}');\n" f"@import url('./{Path('preprocess') / Path('styles_b.css')!s}');" @@ -297,7 +296,7 @@ def test_compile_stylesheets_exclude_tailwind(tmp_path, mocker: MockerFixture): assert compiler.compile_root_stylesheet(stylesheets) == ( str(Path(".web") / "styles" / (PageNames.STYLESHEET_ROOT + ".css")), - "@layer __reflex_base;\n@import url('./__reflex_style_reset.css');\n@import url('@radix-ui/themes/styles.css');\n@import url('./style.css');", + "@layer __reflex_base;\n@import url('./__reflex_style_reset.css');\n@import url('./style.css');", ) @@ -336,10 +335,75 @@ def test_compile_stylesheets_no_reset(tmp_path: Path, mocker: MockerFixture): / "styles" / (PageNames.STYLESHEET_ROOT + ".css") ), - "@layer __reflex_base;\n@import url('@radix-ui/themes/styles.css');\n@import url('./style.css');", + "@layer __reflex_base;\n@import url('./style.css');", + ) + + +def test_compile_stylesheets_includes_radix_plugin( + tmp_path: Path, mocker: MockerFixture +): + """Explicit RadixThemesPlugin should add the Radix stylesheet import.""" + project = tmp_path / "test_project" + project.mkdir() + + assets_dir = project / "assets" + assets_dir.mkdir() + (assets_dir / "style.css").write_text(".root { color: red; }") + + config = mocker.Mock() + config.plugins = [rx.plugins.RadixThemesPlugin()] + mocker.patch("reflex.compiler.compiler.get_config", return_value=config) + mocker.patch("reflex.compiler.compiler.Path.cwd", return_value=project) + mocker.patch( + "reflex.compiler.compiler.get_web_dir", + return_value=project / constants.Dirs.WEB, + ) + mocker.patch( + "reflex.compiler.utils.get_web_dir", return_value=project / constants.Dirs.WEB + ) + + assert compiler.compile_root_stylesheet(["/style.css"]) == ( + str( + project + / constants.Dirs.WEB + / "styles" + / (PageNames.STYLESHEET_ROOT + ".css") + ), + "@layer __reflex_base;\n@import url('./__reflex_style_reset.css');\n@import url('@radix-ui/themes/styles.css');\n@import url('./style.css');", ) +def test_compile_app_root_omits_radix_window_library_by_default(): + """Apps without Radix should not import it in the app root.""" + reset_bundled_libraries() + + _, code = compiler.compile_app_root(rx.el.div("hello")) + + assert "@radix-ui/themes" not in code + + +def test_compile_app_root_includes_radix_window_library_when_bundled(): + """Bundled Radix libraries should be exposed to window.__reflex.""" + reset_bundled_libraries() + try: + bundle_library("@radix-ui/themes@3.3.0") + + _, code = compiler.compile_app_root(rx.el.div("hello")) + + assert 'import * as radix_ui_themes from "@radix-ui/themes";' in code + assert '"@radix-ui/themes": radix_ui_themes' in code + finally: + reset_bundled_libraries() + + +def test_compile_contexts_has_default_color_mode_context(): + """ColorModeContext should have a safe fallback value without Radix.""" + _, code = compiler.compile_contexts(None, None) + + assert "createContext({" in code + assert 'resolvedColorMode: defaultColorMode === "dark" ? "dark" : "light"' in code + + def test_compile_nonexistent_stylesheet(tmp_path, mocker: MockerFixture): """Test that an error is thrown for non-existent stylesheets. diff --git a/tests/units/compiler/test_dynamic_components_codegen.py b/tests/units/compiler/test_dynamic_components_codegen.py new file mode 100644 index 00000000000..fe859984346 --- /dev/null +++ b/tests/units/compiler/test_dynamic_components_codegen.py @@ -0,0 +1,107 @@ +"""Code generation tests for dynamic components.""" + +from pathlib import Path + +from reflex_base.utils import serializers + +import reflex as rx +from reflex.state import State + +STATE_JS_TEMPLATE = ( + Path(__file__).parents[3] + / "packages/reflex-base/src/reflex_base/.templates/web/utils/state.js" +) + + +def test_dynamic_component_codegen_wires_event_handlers() -> None: + """Dynamic component codegen should preserve backend event handlers.""" + state = State(_reflex_internal_init=True) # pyright: ignore[reportCallIssue] + component = rx.el.div( + rx.el.button("hydrate", on_click=State.set_is_hydrated(True)), + rx.el.span(state.is_hydrated), + rx.el.button("unhydrate", on_click=State.set_is_hydrated(False)), + ) + code = serializers.serialize(component) + + assert isinstance(code, str) + assert code.startswith("//__reflex_evaluate") + assert "const {Fragment,useContext,useEffect}" in code + assert "const {EventLoopContext} = window['__reflex'][\"$/utils/context\"]" in code + assert ( + "const {ReflexEvent,applyEventActions} = window['__reflex'][\"$/utils/state\"]" + in code + ) + assert "const [addEvents, connectErrors] = useContext(EventLoopContext);" in code + assert code.count("onClick:") == 2 + assert code.count("addEvents(") == 2 + assert code.count("ReflexEvent(") == 2 + assert ( + 'ReflexEvent("reflex___state____state.set_is_hydrated", ' + '({ ["value"] : true }), ({ }))' + ) in code + assert ( + 'ReflexEvent("reflex___state____state.set_is_hydrated", ' + '({ ["value"] : false }), ({ }))' + ) in code + + +def test_dynamic_component_codegen_wires_state_var_counter_events() -> None: + """Dynamic component codegen should preserve stateful counter event handlers.""" + + class DynamicCounterCodegenState(rx.State): + count: int = 0 + + @rx.event + def set_count(self, count: int): + """Set the counter value. + + Args: + count: The new counter value. + """ + self.count = count + + @rx.var + def counter_ui(self) -> rx.Component: + """Get a dynamic counter component. + + Returns: + The dynamic counter component. + """ + return rx.hstack( + rx.button( + "-", + on_click=DynamicCounterCodegenState.set_count(self.count - 1), + ), + rx.text(self.count, size="9"), + rx.button( + "+", + on_click=DynamicCounterCodegenState.set_count(self.count + 1), + ), + spacing="5", + justify="center", + ) + + state = DynamicCounterCodegenState(_reflex_internal_init=True) # pyright: ignore[reportCallIssue] + code = serializers.serialize(state.counter_ui) + + assert isinstance(code, str) + assert code.startswith("//__reflex_evaluate") + assert "RadixThemesFlex" in code + assert "RadixThemesButton" in code + assert "RadixThemesText" in code + assert 'justify:"center"' in code + assert 'gap:"5"' in code + assert "const {Fragment,useContext,useEffect}" in code + assert "const {EventLoopContext} = window['__reflex'][\"$/utils/context\"]" in code + assert ( + "const {ReflexEvent,applyEventActions} = window['__reflex'][\"$/utils/state\"]" + in code + ) + assert "const [addEvents, connectErrors] = useContext(EventLoopContext);" in code + assert code.count("onClick:") == 2 + assert code.count("addEvents(") == 2 + assert code.count("ReflexEvent(") == 2 + assert code.count(".set_count") == 2 + assert '({ ["count"] : -1 }), ({ })' in code + assert '({ ["count"] : 1 }), ({ })' in code + assert 'jsx(RadixThemesText, ({as:"p",size:"9"}), 0)' in code diff --git a/tests/units/compiler/test_memoize_plugin.py b/tests/units/compiler/test_memoize_plugin.py new file mode 100644 index 00000000000..18fc3b3e5dc --- /dev/null +++ b/tests/units/compiler/test_memoize_plugin.py @@ -0,0 +1,2003 @@ +# ruff: noqa: D101 + +import dataclasses +import re +from collections.abc import Callable +from pathlib import Path +from types import SimpleNamespace +from typing import Any, cast + +import pytest +from reflex_base.components.component import Component, field +from reflex_base.components.memoize_helpers import ( + MemoizationStrategy, + get_memoization_strategy, +) +from reflex_base.constants.compiler import MemoizationDisposition, MemoizationMode +from reflex_base.plugins import CompileContext, CompilerHooks, PageContext +from reflex_base.vars import VarData +from reflex_base.vars.base import LiteralVar, Var +from reflex_components_core.base.bare import Bare +from reflex_components_core.base.fragment import Fragment +from reflex_components_core.base.link import RawLink, ScriptTag +from reflex_components_core.el.elements.forms import BaseInput, Textarea +from reflex_components_core.el.elements.inline import Br, Wbr +from reflex_components_core.el.elements.media import ( + Area, + Desc, + Embed, + Img, + Source, + SvgStyle, + Track, +) +from reflex_components_core.el.elements.media import Script as SvgScript +from reflex_components_core.el.elements.media import Title as SvgTitle +from reflex_components_core.el.elements.metadata import Base, Link, Meta, StyleEl, Title +from reflex_components_core.el.elements.scripts import Noscript, Script +from reflex_components_core.el.elements.tables import Col +from reflex_components_core.el.elements.typography import Hr + +import reflex as rx +import reflex.compiler.plugins.memoize as memoize_plugin +from reflex.compiler.plugins import DefaultCollectorPlugin, default_page_plugins +from reflex.compiler.plugins.memoize import MemoizeStatefulPlugin, _should_memoize +from reflex.experimental.memo import ( + ExperimentalMemoComponent, + create_passthrough_component_memo, +) +from reflex.state import BaseState + +STATE_VAR = LiteralVar.create("value")._replace( + merge_var_data=VarData(hooks={"useTestState": None}, state="TestState") +) + + +class Plain(Component): + tag = "Plain" + library = "plain-lib" + + +class WithProp(Component): + tag = "WithProp" + library = "with-prop-lib" + + label: Var[str] = field(default=LiteralVar.create("")) + + +class LeafComponent(Component): + tag = "LeafComponent" + library = "leaf-lib" + _memoization_mode = MemoizationMode(recursive=False) + + +class SpecialFormMemoState(BaseState): + items: list[str] = ["a"] + flag: bool = True + value: str = "a" + + +@dataclasses.dataclass(slots=True) +class FakePage: + route: str + component: Callable[[], Component] + title: Any = None + description: Any = None + image: str = "" + meta: tuple[dict[str, Any], ...] = () + + +def _compile_single_page( + component_factory: Callable[[], Component], +) -> tuple[CompileContext, PageContext]: + ctx = CompileContext( + pages=[FakePage(route="/p", component=component_factory)], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx: + ctx.compile() + return ctx, ctx.compiled_pages["/p"] + + +def test_should_memoize_catches_direct_state_var_in_prop() -> None: + """A component whose own prop carries state VarData should memoize.""" + comp = WithProp.create(label=STATE_VAR) + assert _should_memoize(comp) + + +def test_should_not_memoize_state_var_in_child_bare() -> None: + """A component whose Bare child contains state VarData should memoize.""" + comp = Plain.create(STATE_VAR) + assert not _should_memoize(comp) + + +def test_should_not_memoize_plain_component() -> None: + """A component with no state vars and no event triggers is not memoized.""" + comp = Plain.create(LiteralVar.create("static-content")) + assert not _should_memoize(comp) + + +def test_should_memoize_state_var_in_child_cond() -> None: + """A Bare containing state VarData should memoize.""" + comp = Bare.create(STATE_VAR) + assert _should_memoize(comp) + + +def test_should_not_memoize_prop_var_with_imports_only_var_data() -> None: + """Prop Vars carrying only imports (no state/hooks) must not trigger memoize. + + Regression: a ``class_name`` produced by the ``cn`` helper (clsx-for-tailwind) + has VarData with non-empty ``imports`` but empty ``state`` and ``hooks``; + snapshot-boundary elements like ``<textarea>`` were being wrapped in memo + purely because of that helper import. + """ + from reflex_base.utils.imports import ImportVar + + import_only_var = LiteralVar.create("static-class")._replace( + merge_var_data=VarData( + imports={"clsx-for-tailwind": [ImportVar(tag="cn")]}, + ) + ) + comp = WithProp.create(label=import_only_var) + assert not _should_memoize(comp) + # Snapshot-boundary form of the same: a Textarea whose only stateful-looking + # signal is an import-bearing class_name should not be memoized either. + boundary = Textarea.create(class_name=import_only_var, name="x") + assert not _should_memoize(boundary) + # And the Bare-contents short-circuit must use the same predicate: a Bare + # wrapping a Var with import-only var_data must not be memoized. + bare = Bare.create(import_only_var) + assert not _should_memoize(bare) + + +def test_should_not_memoize_when_disposition_never() -> None: + """``MemoizationDisposition.NEVER`` overrides heuristic eligibility.""" + comp = Plain.create(STATE_VAR) + object.__setattr__( + comp, + "_memoization_mode", + dataclasses.replace( + comp._memoization_mode, disposition=MemoizationDisposition.NEVER + ), + ) + assert not _should_memoize(comp) + + +def test_memoize_wrapper_uses_experimental_memo_component_and_call_site() -> None: + """Memoizable component imports a generated ``rx._x.memo`` wrapper.""" + ctx, page_ctx = _compile_single_page(lambda: Plain.create(STATE_VAR)) + + assert len(ctx.memoize_wrappers) == 1 + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert wrapper_tag in ctx.auto_memo_components + output = page_ctx.output_code or "" + assert f'import {{{wrapper_tag}}} from "$/utils/components/{wrapper_tag}"' in output + assert f"jsx({wrapper_tag}," in (page_ctx.output_code or "") + assert f"const {wrapper_tag} = memo" not in output + + +def test_memoize_wrapper_deduped_across_repeated_subtrees() -> None: + """Two identical memoizable call-sites collapse to one memo definition.""" + ctx, page_ctx = _compile_single_page( + lambda: Fragment.create( + Plain.create(STATE_VAR), + Plain.create(STATE_VAR), + ) + ) + assert len(ctx.memoize_wrappers) == 1 + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert list(ctx.auto_memo_components) == [wrapper_tag] + assert (page_ctx.output_code or "").count( + f'import {{{wrapper_tag}}} from "$/utils/components/{wrapper_tag}"' + ) == 1 + + +@pytest.mark.parametrize( + ("special_form", "body_marker"), + [ + ("foreach", "Array.prototype.map.call"), + ], +) +def test_special_form_memo_wrappers_render_structural_body( + special_form: str, + body_marker: str, +) -> None: + """Generated memo wrappers for special forms render the structural body. + + The memo body must subscribe to the state the special form references + (via ``useContext(StateContexts...)``), and the page must not — otherwise + the state-dependent render has leaked into page scope. + """ + from reflex.compiler.compiler import compile_memo_components + + def special_child() -> Component: + if special_form == "foreach": + return rx.foreach( + SpecialFormMemoState.items, + lambda item: rx.text(item), + ) + if special_form == "cond": + return cast( + Component, + rx.cond( + SpecialFormMemoState.flag, + rx.text("yes"), + rx.text("no"), + ), + ) + return cast( + Component, + rx.match( + SpecialFormMemoState.value, + ("a", rx.text("A")), + rx.text("default"), + ), + ) + + ctx, page_ctx = _compile_single_page(lambda: rx.box(special_child())) + + memo_files, _memo_imports = compile_memo_components( + components=(), + experimental_memos=tuple(ctx.auto_memo_components.values()), + ) + memo_code = "\n".join(code for _, code in memo_files) + + state_wiring = "useContext(StateContexts" + page_output = page_ctx.output_code + assert page_output is not None + assert state_wiring in memo_code + assert state_wiring not in page_output + assert body_marker in memo_code + assert body_marker not in page_output + + +def test_common_memoization_snapshot_helper_classifies_snapshot_cases() -> None: + """The shared memoization strategy classifies structural render forms.""" + from reflex_components_core.core.cond import Cond + from reflex_components_core.core.match import Match + from reflex_components_core.el.elements.forms import Form, Input + + foreach_parent = rx.box( + rx.foreach( + SpecialFormMemoState.items, + lambda item: rx.text(item), + ) + ) + cond_fragment = cast( + Component, + rx.cond( + SpecialFormMemoState.flag, + rx.text("yes"), + rx.text("no"), + ), + ) + match_fragment = cast( + Component, + rx.match( + SpecialFormMemoState.value, + ("a", rx.text("A")), + rx.text("default"), + ), + ) + + assert get_memoization_strategy(foreach_parent) is MemoizationStrategy.SNAPSHOT + assert get_memoization_strategy(cond_fragment) is MemoizationStrategy.PASSTHROUGH + # Cond and Match now use passthrough so branch JSX renders on the page side + # and the memo body just selects via children[i] indexing. + assert isinstance(cond_fragment.children[0], Cond) + assert ( + get_memoization_strategy(cond_fragment.children[0]) + is MemoizationStrategy.PASSTHROUGH + ) + assert isinstance(match_fragment.children[0], Match) + assert ( + get_memoization_strategy(match_fragment.children[0]) + is MemoizationStrategy.PASSTHROUGH + ) + assert ( + get_memoization_strategy(LeafComponent.create(Plain.create())) + is MemoizationStrategy.SNAPSHOT + ) + + form = Form.create(Input.create(name="username", id="username")) + assert get_memoization_strategy(form) is MemoizationStrategy.PASSTHROUGH + + +def test_generated_memo_component_is_not_itself_memoized() -> None: + """The generated memo component instance itself is skipped by the heuristic.""" + wrapper_factory, _definition = create_passthrough_component_memo( + "MyTag", Fragment.create() + ) + wrapper = wrapper_factory(Plain.create()) + assert isinstance(wrapper, ExperimentalMemoComponent) + assert not _should_memoize(wrapper) + + +def test_event_trigger_memoization_not_emit_usecallback_in_page_hooks() -> None: + """Components with event triggers do not get useCallback wrappers at the page level.""" + from reflex_base.event import EventChain + + # Construct an event chain referencing state so _get_memoized_event_triggers + # emits a useCallback. + event_var = Var(_js_expr="test_event")._replace( + _var_type=EventChain, + merge_var_data=VarData(state="TestState"), + ) + comp = Plain.create() + comp.event_triggers["on_click"] = event_var + + _ctx, page_ctx = _compile_single_page(lambda: comp) + + # Check that a useCallback hook line was added to the page hooks dict. + hook_lines = list(page_ctx.hooks.keys()) + assert not any( + "useCallback" in hook_line and "on_click_" in hook_line + for hook_line in hook_lines + ), f"Expected no on_click useCallback hook in {hook_lines!r}" + + +def test_generated_memo_component_renders_as_its_exported_tag() -> None: + """The generated experimental memo component renders as its exported tag.""" + wrapper_factory, definition = create_passthrough_component_memo( + "MyWrapper_abc", Fragment.create() + ) + wrapper = wrapper_factory(Plain.create()) + assert isinstance(wrapper, ExperimentalMemoComponent) + assert wrapper.tag == "MyWrapper_abc" + assert definition.export_name == "MyWrapper_abc" + assert wrapper.render()["name"] == "MyWrapper_abc" + + +def test_passthrough_memo_definitions_are_not_shared_globally(monkeypatch) -> None: + """Repeated tags across compiles rebuild their passthrough definitions. + + Regression: sharing auto-memo definitions globally by tag leaks the first + app's captured component tree into later compiles, which can stale-bind + state event names across AppHarness apps. + """ + tag = "SharedMemoTag" + first_component = Plain.create(STATE_VAR) + second_component = Plain.create(STATE_VAR) + + monkeypatch.setattr(memoize_plugin, "_compute_memo_tag", lambda comp: tag) + monkeypatch.setattr( + memoize_plugin, + "fix_event_triggers_for_memo", + lambda comp, page_context: comp, + ) + + def fake_create_passthrough_component_memo( + export_name: str, + component: Component, + ): + definition = SimpleNamespace(export_name=export_name, component=component) + return (lambda definition=definition: definition), definition + + monkeypatch.setattr( + memoize_plugin, + "create_passthrough_component_memo", + fake_create_passthrough_component_memo, + ) + + first_compile = SimpleNamespace(memoize_wrappers={}, auto_memo_components={}) + second_compile = SimpleNamespace(memoize_wrappers={}, auto_memo_components={}) + page_context = cast(PageContext, SimpleNamespace()) + + MemoizeStatefulPlugin._build_wrapper( + first_component, + page_context=page_context, + compile_context=first_compile, + ) + MemoizeStatefulPlugin._build_wrapper( + second_component, + page_context=page_context, + compile_context=second_compile, + ) + + first_definition = first_compile.auto_memo_components[tag] + second_definition = second_compile.auto_memo_components[tag] + assert first_definition.component is first_component + assert second_definition.component is second_component + assert second_definition is not first_definition + + +def test_shared_subtree_across_pages_uses_same_tag() -> None: + """The same memoizable subtree on multiple pages gets one shared tag.""" + ctx = CompileContext( + pages=[ + FakePage(route="/a", component=lambda: Plain.create(STATE_VAR)), + FakePage(route="/b", component=lambda: Plain.create(STATE_VAR)), + ], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx: + ctx.compile() + + assert len(ctx.memoize_wrappers) == 1 + tag = next(iter(ctx.memoize_wrappers)) + assert list(ctx.auto_memo_components) == [tag] + for route in ("/a", "/b"): + output = ctx.compiled_pages[route].output_code or "" + assert f'import {{{tag}}} from "$/utils/components/{tag}"' in output + assert f"jsx({tag}," in output + + +def test_shared_parent_instance_across_pages_preserves_original() -> None: + """A parent instance reused across pages must not have its children rebound. + + Regression: the compile walker replaces memoizable descendants with memo + wrappers and writes the new children list onto their parent. If the parent + is the same Python object on two pages (e.g. a module-scope layout), page + A's compile would mutate page B's starting tree, producing a ``ReferenceError`` + for the memo tag on the second page. + """ + shared_parent = Fragment.create(WithProp.create(label=STATE_VAR)) + original_children = list(shared_parent.children) + original_child = shared_parent.children[0] + + ctx = CompileContext( + pages=[ + FakePage(route="/a", component=lambda: shared_parent), + FakePage(route="/b", component=lambda: shared_parent), + ], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx: + ctx.compile() + + assert shared_parent.children == original_children, ( + f"shared parent's children mutated: {shared_parent.children!r}" + ) + assert shared_parent.children[0] is original_child, ( + "shared parent's child reference replaced by a memo wrapper" + ) + + assert len(ctx.memoize_wrappers) == 1 + tag = next(iter(ctx.memoize_wrappers)) + for route in ("/a", "/b"): + output = ctx.compiled_pages[route].output_code or "" + assert f'import {{{tag}}} from "$/utils/components/{tag}"' in output, ( + f"route {route} missing memo tag import" + ) + assert f"jsx({tag}," in output, f"route {route} does not render the memo tag" + + +def test_shared_nested_parent_mirroring_common_elements_preserves_original() -> None: + """Deeper nested shape — mirrors ``common_elements`` in test_event_chain. + + ``common_elements`` is an outer ``rx.vstack`` that contains an inner + ``rx.vstack(rx.foreach(...))`` memoizable subtree. The walker must clone + the entire spine from the memoized descendant up to the shared root, not + just the immediate parent. + """ + inner_parent = Fragment.create(WithProp.create(label=STATE_VAR)) + shared_outer = Fragment.create( + WithProp.create(label=LiteralVar.create("static")), + inner_parent, + WithProp.create(label=LiteralVar.create("trailing")), + ) + original_outer_children = list(shared_outer.children) + original_inner = shared_outer.children[1] + original_inner_children = list(inner_parent.children) + original_innermost = inner_parent.children[0] + + ctx = CompileContext( + pages=[ + FakePage(route="/a", component=lambda: shared_outer), + FakePage(route="/b", component=lambda: shared_outer), + FakePage(route="/c", component=lambda: shared_outer), + ], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx: + ctx.compile() + + assert shared_outer.children == original_outer_children + assert shared_outer.children[1] is original_inner + assert inner_parent.children == original_inner_children + assert inner_parent.children[0] is original_innermost + + assert len(ctx.memoize_wrappers) == 1 + tag = next(iter(ctx.memoize_wrappers)) + for route in ("/a", "/b", "/c"): + output = ctx.compiled_pages[route].output_code or "" + assert f'import {{{tag}}} from "$/utils/components/{tag}"' in output + assert f"jsx({tag}," in output + + +def test_memoization_leaf_internal_hooks_do_not_leak_into_page() -> None: + """Hooks from a ``MemoizationLeaf``'s internal children stay in its memo body. + + ``MemoizationLeaf``-derived components (e.g. ``rx.upload.root``) build + internal machinery as their own structural children, attaching stateful + hooks via ``special_props``/``VarData``. Those hooks belong to the memo + component's function body — not to the page — because the whole point of + the leaf is to isolate its subtree from page-level re-renders. + + The test asserts both directions: the hook lines do not appear in the + page's collected hooks, *and* they do appear in the compiled memo module + (otherwise a regression that drops them entirely would pass the negative + check). + """ + from reflex_base.components.component import MemoizationLeaf + from reflex_base.event import EventChain + from reflex_base.vars.base import Var + + from reflex.compiler.compiler import compile_memo_components + + class StatefulLeaf(MemoizationLeaf): + tag = "StatefulLeaf" + library = "stateful-leaf-lib" + + @classmethod + def create(cls, *children, **props): + # Simulate what rx.upload.root does: build an internal child whose + # special_props carry stateful hook lines via VarData. + internal_hook_var = Var( + _js_expr="__internal_leaf_probe()", + _var_type=None, + _var_data=VarData( + hooks={ + "const __internal_leaf_probe = useLeafProbe();": None, + "const on_drop_xyz = useCallback(() => {}, []);": None, + }, + state="LeafState", + ), + ) + internal_child = Plain.create(*children) + internal_child.special_props = [internal_hook_var] + return super().create(internal_child, **props) + + stateful_event = Var(_js_expr="evt")._replace( + _var_type=EventChain, + merge_var_data=VarData(state="LeafState"), + ) + leaf = StatefulLeaf.create() + leaf.event_triggers["on_something"] = stateful_event + + ctx, page_ctx = _compile_single_page(lambda: leaf) + + page_hook_lines = list(page_ctx.hooks) + leaking_hooks = [ + hook + for hook in page_hook_lines + if "useLeafProbe" in hook or "on_drop_xyz" in hook + ] + assert not leaking_hooks, ( + f"MemoizationLeaf internal hooks leaked into page: {leaking_hooks!r}" + ) + + # The hooks must survive somewhere — in the compiled memo module for the + # generated leaf wrapper. Compile the auto-memo definitions collected + # during the page compile and check that the hook lines are present. + assert ctx.auto_memo_components, ( + "expected an auto-memo wrapper to be generated for the leaf" + ) + memo_files, _memo_imports = compile_memo_components( + components=(), + experimental_memos=tuple(ctx.auto_memo_components.values()), + ) + memo_code = "\n".join(code for _, code in memo_files) + assert "useLeafProbe" in memo_code, ( + "leaf's internal probe hook was dropped from the memo module" + ) + assert "on_drop_xyz" in memo_code, ( + "leaf's internal useCallback hook was dropped from the memo module" + ) + + +def test_plugin_only_registered_once_in_default_page_plugins() -> None: + """MemoizeStatefulPlugin appears exactly once in the default plugin pipeline.""" + plugins = default_page_plugins() + memoize_plugins = [p for p in plugins if isinstance(p, MemoizeStatefulPlugin)] + assert len(memoize_plugins) == 1 + # And it is registered after the DefaultCollectorPlugin. + collector_index = next( + i for i, p in enumerate(plugins) if isinstance(p, DefaultCollectorPlugin) + ) + memoize_index = plugins.index(memoize_plugins[0]) + assert memoize_index > collector_index + + +def test_match_non_stateful_cond_allows_stateful_children_to_memoize() -> None: + """Match with a non-stateful condition must not suppress child memoization. + + Regression: Match was a MemoizationLeaf, causing it to push onto the + suppressor stack when its condition had no VarData. That blocked + independently-stateful children from being wrapped. After the fix Match + is a plain Component and its stateful children are memoized normally. + """ + + def page() -> Component: + comp = rx.match( + "static", # non-stateful condition + ("a", WithProp.create(label=STATE_VAR)), + WithProp.create(label=LiteralVar.create("default")), + ) + assert isinstance(comp, Component) + return comp + + ctx, _page_ctx = _compile_single_page(page) + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected the stateful WithProp inside match cases to be memoized, " + f"got wrappers: {list(ctx.memoize_wrappers)}" + ) + + +def test_cond_non_stateful_cond_allows_stateful_children_to_memoize() -> None: + """Cond with a non-stateful condition must not suppress child memoization. + + When the condition carries no VarData, Cond should not be extracted to its + own memo component. Its stateful children (comp1 / comp2) should still be + independently memoized. + """ + + def page() -> Component: + comp = rx.cond( + True, # non-stateful condition + WithProp.create(label=STATE_VAR), + WithProp.create(label=LiteralVar.create("false-branch")), + ) + assert isinstance(comp, Component) + return comp + + ctx, _page_ctx = _compile_single_page(page) + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected the stateful WithProp inside cond branch to be memoized, " + f"got wrappers: {list(ctx.memoize_wrappers)}" + ) + + +def test_cond_and_match_strategy_classification() -> None: + """Cond and Match both use passthrough; branches render on the page side.""" + from reflex_components_core.core.cond import Cond + from reflex_components_core.core.match import Match + + cond_non_stateful = rx.cond( + True, + rx.text("yes"), + rx.text("no"), + ) + cond_stateful = rx.cond( + SpecialFormMemoState.flag, + rx.text("yes"), + rx.text("no"), + ) + match_non_stateful = rx.match( + "static", + ("a", rx.text("A")), + rx.text("default"), + ) + match_stateful = rx.match( + SpecialFormMemoState.value, + ("a", rx.text("A")), + rx.text("default"), + ) + + for comp in (cond_non_stateful, cond_stateful): + assert isinstance(comp, Component) + assert get_memoization_strategy(comp) is MemoizationStrategy.PASSTHROUGH + assert isinstance(comp.children[0], Cond) + assert ( + get_memoization_strategy(comp.children[0]) + is MemoizationStrategy.PASSTHROUGH + ) + + for comp in (match_non_stateful, match_stateful): + assert isinstance(comp, Component) + assert isinstance(comp.children[0], Match) + assert ( + get_memoization_strategy(comp.children[0]) + is MemoizationStrategy.PASSTHROUGH + ) + + +def test_cond_stateful_var_branch_memoized_as_bare() -> None: + """rx.cond(True, STATE_VAR, "false") embeds a stateful ternary Var in a Bare. + + The ternary Var produced by the Var-returning cond path carries STATE_VAR's + VarData. When rendered inside rx.box it appears as a Bare child, which must + be extracted into its own memoized component. + """ + ctx, _page_ctx = _compile_single_page( + lambda: rx.box(rx.cond(True, STATE_VAR, "false")), + ) + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected stateful cond ternary var to produce one memoized Bare, " + f"got wrappers: {list(ctx.memoize_wrappers)}" + ) + + +def test_cond_stateful_condition_memoizes_whole_cond_and_stateful_branch() -> None: + """Stateful Cond condition memoizes both Cond and stateful branch. + + Cond should recurse into branches so stateful branch components are wrapped + independently, while the Cond itself is also wrapped because its condition + var reads state. + """ + + def page() -> Component: + comp = rx.cond( + SpecialFormMemoState.flag, + WithProp.create(label=STATE_VAR), + WithProp.create(label=LiteralVar.create("false-branch")), + ) + assert isinstance(comp, Component) + return comp + + ctx, _page_ctx = _compile_single_page(page) + + assert len(ctx.memoize_wrappers) == 2, ( + "Expected both Cond and its stateful branch component to be memoized, " + f"got wrappers: {list(ctx.memoize_wrappers)}" + ) + wrapper_tags = tuple(ctx.memoize_wrappers) + assert any("cond" in tag.lower() for tag in wrapper_tags) + assert any("withprop" in tag.lower() for tag in wrapper_tags) + + +def test_match_stateful_condition_memoizes_whole_match_and_stateful_branch() -> None: + """Stateful Match condition memoizes both Match and stateful branch. + + Match should recurse into branches so stateful branch components are + memoized independently, while Match itself is memoized when its condition + var carries VarData. + """ + + def page() -> Component: + comp = rx.match( + SpecialFormMemoState.value, + ("a", WithProp.create(label=STATE_VAR)), + WithProp.create(label=LiteralVar.create("default")), + ) + assert isinstance(comp, Component) + return comp + + ctx, _page_ctx = _compile_single_page(page) + assert len(ctx.memoize_wrappers) == 2, ( + "Expected both Match and its stateful branch component to be memoized, " + f"got wrappers: {list(ctx.memoize_wrappers)}" + ) + wrapper_tags = tuple(ctx.memoize_wrappers) + assert any("match" in tag.lower() for tag in wrapper_tags) + assert any("withprop" in tag.lower() for tag in wrapper_tags) + + +def test_cond_stateful_branch_component_renders_via_memoized_wrapper() -> None: + """Components inside Cond branches must render via their memo wrappers. + + Regression shape matching the Match case: when the walker memoizes a + branch component, Cond rendering must use the wrapped branch tag in page + output rather than the original unwrapped component tag. + """ + + def page() -> Component: + comp = rx.cond( + True, + WithProp.create(label=STATE_VAR), + WithProp.create(label=LiteralVar.create("false-branch")), + ) + assert isinstance(comp, Component) + return comp + + ctx, page_ctx = _compile_single_page(page) + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected stateful branch to produce one memo wrapper, got: {list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + output = page_ctx.output_code or "" + assert f"jsx({wrapper_tag}," in output, ( + f"Memo wrapper {wrapper_tag!r} not found in page output.\n" + f"Output snippet: {output[:2000]}" + ) + + +def test_cond_stateful_condition_renders_branch_logic_in_memo_body() -> None: + """Stateful Cond memo body must select both branches via ``children`` indexing. + + Cond is now a passthrough wrapper: branch JSX is rendered on the page side + and passed as the ``children`` array. The memo body's ternary must select + ``children[0]`` for the true branch and ``children[1]`` for the false + branch — neither branch should collapse to a generic ``? children`` hole + nor inline the original branch text into the memo body. + """ + from reflex.compiler.compiler import compile_memo_components + + def page() -> Component: + comp = rx.cond( + SpecialFormMemoState.flag, + rx.text("yes"), + rx.text("no"), + ) + assert isinstance(comp, Component) + return comp + + ctx, page_ctx = _compile_single_page(page) + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected stateful Cond to produce one memo wrapper, got: {list(ctx.memoize_wrappers)}" + ) + + memo_files, _memo_imports = compile_memo_components( + components=(), + experimental_memos=tuple(ctx.auto_memo_components.values()), + ) + memo_code = "\n".join(code for _, code in memo_files) + + assert "children?.at?.(0)" in memo_code, ( + "Cond memo body should select the true branch via children[0].\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + assert "children?.at?.(1)" in memo_code, ( + "Cond memo body should select the false branch via children[1].\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + assert '"yes"' not in memo_code, ( + "Cond memo body unexpectedly inlined the true branch.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + assert '"no"' not in memo_code, ( + "Cond memo body unexpectedly inlined the false branch.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + + page_output = page_ctx.output_code or "" + assert '"yes"' in page_output, ( + "Page output should render the true branch as a memo wrapper child.\n" + f"Page output snippet: {page_output[:2000]}" + ) + assert '"no"' in page_output, ( + "Page output should render the false branch as a memo wrapper child.\n" + f"Page output snippet: {page_output[:2000]}" + ) + + +def test_match_stateful_branch_component_renders_via_memoized_wrapper() -> None: + """Components inside Match branches must be rendered via their memo wrappers. + + Regression: Match._render() used self.match_cases / self.default directly + instead of self.children. The walker updates children when it memoizes a + branch component, but those updates were invisible to Match's render, so + the generated page JSX still referenced the original unwrapped component + tag rather than the memo wrapper. + """ + + def page() -> Component: + comp = rx.match( + "static", + ("a", WithProp.create(label=STATE_VAR)), + WithProp.create(label=LiteralVar.create("default")), + ) + assert isinstance(comp, Component) + return comp + + ctx, page_ctx = _compile_single_page(page) + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected stateful branch to produce one memo wrapper, got: {list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + output = page_ctx.output_code or "" + assert f"jsx({wrapper_tag}," in output, ( + f"Memo wrapper {wrapper_tag!r} not found in page output.\n" + f"Output snippet: {output[:2000]}" + ) + + +def test_match_stateful_condition_uses_memoized_branch_wrapper_in_memo_body() -> None: + """Stateful Match passes branch wrappers as page-side children. + + Match is now a passthrough wrapper: when both the match condition and a + branch are stateful, the Match wrapper itself is memoized and the branch + is memoized separately. The Match memo body selects via ``children[i]`` + indexing, and the page output renders the branch wrapper as a child of + the Match wrapper (rather than inlining the unwrapped branch component). + """ + from reflex.compiler.compiler import compile_memo_components + + def page() -> Component: + comp = rx.match( + SpecialFormMemoState.value, + ("a", WithProp.create(label=STATE_VAR)), + WithProp.create(label=LiteralVar.create("default")), + ) + assert isinstance(comp, Component) + return comp + + ctx, page_ctx = _compile_single_page(page) + assert len(ctx.memoize_wrappers) == 2, ( + "Expected both Match and its stateful branch component to be memoized, " + f"got wrappers: {list(ctx.memoize_wrappers)}" + ) + + match_wrapper_tag = next( + tag for tag in ctx.memoize_wrappers if "match" in tag.lower() + ) + branch_wrapper_tag = next( + tag for tag in ctx.memoize_wrappers if "withprop" in tag.lower() + ) + + memo_files, _memo_imports = compile_memo_components( + components=(), + experimental_memos=tuple(ctx.auto_memo_components.values()), + ) + match_memo_code = next( + code + for path, code in memo_files + if Path(path).name == f"{match_wrapper_tag}.jsx" + ) + + assert "children?.at?.(0)" in match_memo_code, ( + "Match memo body should select case 0 via children indexing.\n" + f"Memo code snippet: {match_memo_code[:2000]}" + ) + assert "children?.at?.(1)" in match_memo_code, ( + "Match memo body should select the default via children indexing.\n" + f"Memo code snippet: {match_memo_code[:2000]}" + ) + assert f"jsx({branch_wrapper_tag}," not in match_memo_code, ( + "Match memo body should not inline the branch wrapper; the branch " + "renders on the page side as a memo wrapper child.\n" + f"Memo code snippet: {match_memo_code[:2000]}" + ) + + page_output = page_ctx.output_code or "" + assert f"jsx({match_wrapper_tag}," in page_output, ( + f"Page output should render the Match memo wrapper {match_wrapper_tag!r}.\n" + f"Output snippet: {page_output[:2000]}" + ) + assert f"jsx({branch_wrapper_tag}," in page_output, ( + f"Page output should render the branch memo wrapper {branch_wrapper_tag!r} " + "as a child of the Match wrapper.\n" + f"Output snippet: {page_output[:2000]}" + ) + + +def test_memoized_match_wrapper_receives_case_children_in_page_output() -> None: + """Passthrough Match wrapper receives all case children from the page output. + + With Match handled as a passthrough memo, the page renders each case's JSX + as a child of the Match wrapper. The memo body selects which child to mount + via ``children[i]`` indexing keyed on the (possibly stateful) condition. + """ + + def page() -> Component: + comp = rx.match( + SpecialFormMemoState.value, + ("a", rx.text("A")), + ("b", rx.text("B")), + rx.text("default"), + ) + assert isinstance(comp, Component) + return comp + + ctx, page_ctx = _compile_single_page(page) + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected stateful Match to produce one memo wrapper, got: {list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + output = page_ctx.output_code or "" + + assert f"jsx({wrapper_tag}," in output, ( + f"Memo wrapper {wrapper_tag!r} not found in page output.\n" + f"Output snippet: {output[:2000]}" + ) + # Each case-return JSX, plus the default, must reach the wrapper as a child. + for case_text in ('"A"', '"B"', '"default"'): + assert case_text in output, ( + f"Expected case JSX {case_text} in page output as a Match wrapper child.\n" + f"Output snippet: {output[:2000]}" + ) + # Match wrapper must be called with three positional children (the cases plus + # default), not as an empty-children call. + assert re.search( + rf"jsx\({re.escape(wrapper_tag)},\s*\{{\}},\s*jsx\(", + output, + ), ( + "Match wrapper should receive case JSX as positional children in page output.\n" + f"Output snippet: {output[:2000]}" + ) + + +def test_client_state_setter_in_call_function_event_imports_refs() -> None: + """A button whose ``on_click`` calls a global ``ClientStateVar`` setter + must memoize and the resulting memo body's imports must include ``refs`` + from ``$/utils/state``. + + Regression: ``ClientStateVar.set_value`` builds its setter as + ``refs['_client_state_<setter>']`` but the returned setter ``Var`` does not + carry the ``refs`` import. When the on_click event chain is compiled into + the memo body, the body references ``refs['_client_state_<setter>'](42)`` + with no matching ``import { refs } from "$/utils/state"`` — producing a + ``ReferenceError: refs is not defined`` at runtime. + """ + from reflex.compiler.compiler import compile_memo_components + from reflex.experimental.client_state import ClientStateVar + + counter = ClientStateVar.create("counter", default=0) + + def page() -> Component: + return rx.el.button( + "click", + on_click=rx.call_function(counter.set_value(42)), + ) + + ctx, _page_ctx = _compile_single_page(page) + + assert len(ctx.memoize_wrappers) == 1, ( + "Expected the button with a stateful on_click to be auto-memoized, " + f"got wrappers: {list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + + memo_files, _memo_imports = compile_memo_components( + components=(), + experimental_memos=tuple(ctx.auto_memo_components.values()), + ) + memo_code = next( + code for path, code in memo_files if Path(path).name == f"{wrapper_tag}.jsx" + ) + + assert "refs['_client_state_setCounter'](42)" in memo_code, ( + "Expected the memo body to call the client-state setter via refs.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + + state_import_match = re.search( + r'^import\s*\{([^}]*)\}\s*from\s*"\$/utils/state"', + memo_code, + flags=re.MULTILINE, + ) + assert state_import_match is not None, ( + "Memo body must import from $/utils/state since the on_click handler " + "uses refs['_client_state_setCounter'].\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + imported_names = {name.strip() for name in state_import_match.group(1).split(",")} + assert "refs" in imported_names, ( + f"Memo body imports {imported_names!r} from $/utils/state but is missing " + "'refs' — the on_click handler references refs['_client_state_setCounter'].\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + + +def test_debounce_input_memo_renders_react_debounce_wrapper() -> None: + """``rx.input(value=..., on_change=..., debounce_timeout=N)`` memoizes via DebounceInput. + + When ``rx.input`` is given both ``value`` and ``on_change`` it is wrapped by + ``DebounceInput`` so the underlying input is fully controlled without typing + jank. The wrapper carries DebounceInput-known props (``debounce_timeout``, + ``input_ref``, ``element``) and also forwards the inner TextField as the + ``element`` prop. The memo body produced by the auto-memoize plugin must: + + - Import ``DebounceInput`` from ``react-debounce-input`` and render it via + ``jsx(DebounceInput, ...)`` rather than rendering the inner TextField + directly. The whole point of the wrapping is to give react-debounce-input + ownership of the keystroke pipeline; if the memo emitted the inner + ``TextField.Root`` instead, controlled-input updates would race the + backend round-trip and drop characters. + - Pass ``debounceTimeout`` as a real DebounceInput prop, not via ``css``. + Reflex routes unknown TextFieldRoot kwargs (like ``debounce_timeout``) + into ``style`` at component construction; ``DebounceInput.create`` then + copies ``child.style`` into the wrapper, which can leak the timeout into + the rendered ``css`` block. The timeout belongs on the wrapper as a real + prop — leaking it to ``css`` makes it a no-op styling key while the real + debounce behavior depends on the prop alone. + - Wire ``element`` to ``RadixThemesTextField.Root`` so the underlying input + is the radix text field and not a bare ``<input>``. + """ + from reflex.compiler.compiler import compile_memo_components + + class DebounceState(BaseState): + value: str = "" + + @rx.event + def set_value(self, v: str) -> None: + self.value = v + + def page() -> Component: + return rx.input( + id="my_input", + value=DebounceState.value, + on_change=DebounceState.set_value, + debounce_timeout=250, + ) + + ctx, _page_ctx = _compile_single_page(page) + + assert len(ctx.memoize_wrappers) == 1, ( + "Expected the controlled rx.input to memoize as a single DebounceInput " + f"wrapper, got: {list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert "debounceinput" in wrapper_tag.lower(), ( + f"Memo wrapper tag should be derived from DebounceInput, got: {wrapper_tag!r}" + ) + + memo_files, _memo_imports = compile_memo_components( + components=(), + experimental_memos=tuple(ctx.auto_memo_components.values()), + ) + memo_code = next( + code for path, code in memo_files if Path(path).name == f"{wrapper_tag}.jsx" + ) + + assert re.search( + r'^import\s+DebounceInput\s+from\s+"react-debounce-input"', + memo_code, + flags=re.MULTILINE, + ), ( + "Memo body must import DebounceInput from react-debounce-input.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + assert "jsx(DebounceInput," in memo_code, ( + "Memo body must render via DebounceInput, not inline the inner TextField.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + assert "debounceTimeout:250" in memo_code, ( + "Memo body must pass debounceTimeout as a DebounceInput prop.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + assert "element:RadixThemesTextField.Root" in memo_code, ( + "Memo body must pass the radix TextField as DebounceInput's element prop.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + + css_block_match = re.search( + r"css:\(\{([^}]*)\}\)", + memo_code, + ) + css_contents = css_block_match.group(1) if css_block_match else "" + assert "debounceTimeout" not in css_contents, ( + "debounceTimeout leaked into the css block — it should only be a " + "DebounceInput prop. Reflex routes unknown TextFieldRoot kwargs into " + "style, and DebounceInput.create copies child.style verbatim, so the " + "timeout ends up duplicated as a no-op CSS key.\n" + f"css block: {css_contents!r}\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + + +def test_should_memoize_snapshot_boundary_with_stateful_descendant() -> None: + """A snapshot boundary memoizes when its subtree contains state-derived hooks. + + ``LeafComponent`` mirrors the Radix-primitive shape: ``recursive=False`` + set directly without inheriting from ``MemoizationLeaf``. + """ + boundary = LeafComponent.create(Plain.create(STATE_VAR)) + assert _should_memoize(boundary) + + +def test_snapshot_boundary_wraps_subtree_once_when_descendant_is_stateful() -> None: + """A snapshot boundary with a stateful descendant produces exactly one wrapper. + + The boundary owns its subtree; descendants must remain suppressed. End + result: one snapshot wrapper covering the boundary, no independent wrapper + for the stateful descendant. + """ + ctx, page_ctx = _compile_single_page( + lambda: LeafComponent.create(Plain.create(STATE_VAR)) + ) + assert len(ctx.memoize_wrappers) == 1, ( + "Expected exactly one snapshot wrapper covering the leaf and its " + f"stateful descendant. Got: {list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert "leafcomponent" in wrapper_tag.lower(), ( + f"Wrapper should be derived from LeafComponent, got: {wrapper_tag!r}" + ) + output = page_ctx.output_code or "" + assert f"jsx({wrapper_tag}," in output + + +def test_snapshot_boundary_with_static_subtree_is_not_wrapped() -> None: + """A snapshot boundary with no stateful descendant emits no wrapper. + + Sanity check: the new rule fires on subtree state, not on the boundary + flag alone. Static leaves stay on the page as before. + """ + ctx, _page_ctx = _compile_single_page( + lambda: LeafComponent.create(Plain.create(LiteralVar.create("static"))) + ) + assert len(ctx.memoize_wrappers) == 0, ( + f"Expected no wrapper for a fully static boundary; got: {list(ctx.memoize_wrappers)}" + ) + + +def test_snapshot_boundary_with_event_trigger_descendant_is_wrapped() -> None: + """A snapshot boundary with a stateful event-trigger descendant must wrap.""" + from reflex_base.event import EventChain + + event_var = Var(_js_expr="test_event")._replace( + _var_type=EventChain, + merge_var_data=VarData(state="TestState", hooks={"useTestState": None}), + ) + inner = Plain.create(on_click=event_var) + boundary = LeafComponent.create(inner) + assert _should_memoize(boundary), ( + "Snapshot boundary with a stateful event-trigger descendant must memoize." + ) + + +def test_snapshot_boundary_with_no_arg_event_handler_descendant_is_wrapped() -> None: + """A boundary whose descendant has on_click without arg vars still wraps. + + No-arg handlers (``on_click=State.ping``) contribute to the page only + via the descendant's ``event_triggers`` and ``_get_events_hooks`` — the + per-Var subtree scan misses them. The reactive-data check must also + inspect ``event_triggers`` directly so the boundary wraps and the + callback's ``useCallback`` lands inside the snapshot body. + """ + inner = Plain.create() + inner.event_triggers["on_click"] = Var(_js_expr="evt") + boundary = LeafComponent.create(inner) + assert _should_memoize(boundary) + + +def test_title_with_stateful_var_child_does_not_wrap_bare_independently() -> None: + """``rx.el.title(state_var)`` must not produce a Bare component child. + + ``<title>`` is RCDATA — text content only. Wrapping the inner Bare as an + independent memo wrapper renders ``jsx("title", {}, jsx(Bare_xxx, {}))`` + which React refuses to interpolate as text. Marking ``Title`` as a + snapshot boundary keeps the Bare inside the title's snapshot, where it + renders as a text interpolation. + """ + title = Title.create(STATE_VAR) + ctx, page_ctx = _compile_single_page(lambda: title) + + assert len(ctx.memoize_wrappers) == 1, ( + "Expected exactly one snapshot wrapper for the title; got: " + f"{list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert wrapper_tag.lower().startswith("title_"), ( + f"Wrapper should be derived from Title, got: {wrapper_tag!r}" + ) + output = page_ctx.output_code + assert output is not None + assert f"jsx({wrapper_tag}," in output, ( + "The page output must call the snapshot wrapper.\n" + f"Page output snippet: {output[:2000]}" + ) + assert "useTestState" not in output, ( + "The state-bearing hook should live inside the memo body, not the page.\n" + f"Page output snippet: {output[:2000]}" + ) + assert "TestState" not in output, ( + "The state-context wiring should live inside the memo body, not the page.\n" + f"Page output snippet: {output[:2000]}" + ) + + +def test_meta_with_stateful_var_child_does_not_wrap_bare_independently() -> None: + """``rx.el.meta(state_var)`` must not produce a Bare component child. + + ``<meta>`` is a void element — it forbids any children at all. Memoizing + the Bare independently produces ``jsx("meta", {}, jsx(Bare_xxx, {}))`` + which is invalid HTML. + """ + meta = Meta.create(STATE_VAR) + ctx, page_ctx = _compile_single_page(lambda: meta) + + assert len(ctx.memoize_wrappers) == 1, ( + "Expected exactly one snapshot wrapper for the meta; got: " + f"{list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + output = page_ctx.output_code + assert output is not None + assert f"jsx({wrapper_tag}," in output, ( + "The page output must call the meta's snapshot wrapper.\n" + f"Page output snippet: {output[:2000]}" + ) + assert "useTestState" not in output, ( + "The state-bearing hook should live inside the memo body, not the page.\n" + f"Page output snippet: {output[:2000]}" + ) + assert "TestState" not in output, ( + "The state-context wiring should live inside the memo body, not the page.\n" + f"Page output snippet: {output[:2000]}" + ) + + +@pytest.mark.parametrize( + "cls", + [ + pytest.param(StyleEl, id="style"), + pytest.param(Textarea, id="textarea"), + pytest.param(Script, id="script"), + ], +) +def test_text_only_element_with_stateful_var_child_does_not_wrap_bare( + cls: type[Component], +) -> None: + """Text-only HTML elements must not wrap stateful Bare children as components. + + ``<style>``/``<textarea>``/``<script>`` all have raw-text content models. + A JSX component child renders as a stringified ``[object Object]`` — the + text interpolation needs to land inside the element's snapshot body. + """ + component = cls.create(STATE_VAR) + ctx, page_ctx = _compile_single_page(lambda: component) + + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected exactly one snapshot wrapper; got: {list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + output = page_ctx.output_code + assert output is not None + assert f"jsx({wrapper_tag}," in output, ( + "The page output must call the raw-text element's snapshot wrapper.\n" + f"Page output snippet: {output[:2000]}" + ) + assert "useTestState" not in output, ( + "The state-bearing hook must live inside the memo body, not the page.\n" + f"Page output snippet: {output[:2000]}" + ) + assert "TestState" not in output, ( + "The state-context wiring must live inside the memo body, not the page.\n" + f"Page output snippet: {output[:2000]}" + ) + + +def test_accordion_trigger_with_stateful_cond_is_memoized() -> None: + """AccordionTrigger holding a stateful cond wraps as a single snapshot. + + AccordionTrigger sets ``recursive=False`` without inheriting from + ``MemoizationLeaf``; the boundary itself must memoize so the cond's + state read lands inside the snapshot rather than the page module. + """ + from reflex_components_radix.primitives.accordion import AccordionTrigger + + trigger = AccordionTrigger.create( + rx.cond( + SpecialFormMemoState.flag, + rx.text("Hide"), + rx.text("Show"), + ) + ) + ctx, page_ctx = _compile_single_page(lambda: trigger) + + wrapper_tags = list(ctx.memoize_wrappers) + trigger_wrappers = [t for t in wrapper_tags if "trigger" in t.lower()] + assert trigger_wrappers, ( + "AccordionTrigger with a stateful cond must produce its own snapshot " + f"wrapper. Got wrappers: {wrapper_tags}" + ) + + output = page_ctx.output_code + assert output is not None + assert "useContext(StateContexts" not in output, ( + "State read leaked into the page module — the trigger's stateful cond " + "should be captured inside the snapshot wrapper instead.\n" + f"Page output snippet: {output[:2000]}" + ) + + +@pytest.mark.parametrize( + "component_cls", + [ + # text-only (RCDATA / raw text) + pytest.param(Title, id="title"), + pytest.param(StyleEl, id="style"), + pytest.param(Textarea, id="textarea"), + pytest.param(Script, id="script"), + pytest.param(Noscript, id="noscript"), + pytest.param(ScriptTag, id="script_tag"), + # void HTML elements + pytest.param(Meta, id="meta"), + pytest.param(Base, id="base"), + pytest.param(Link, id="link"), + pytest.param(RawLink, id="raw_link"), + pytest.param(BaseInput, id="input"), + pytest.param(Br, id="br"), + pytest.param(Wbr, id="wbr"), + pytest.param(Col, id="col"), + pytest.param(Hr, id="hr"), + pytest.param(Area, id="area"), + pytest.param(Img, id="img"), + pytest.param(Track, id="track"), + pytest.param(Embed, id="embed"), + pytest.param(Source, id="source"), + # SVG raw-text equivalents + pytest.param(Desc, id="svg_desc"), + pytest.param(SvgTitle, id="svg_title"), + pytest.param(SvgScript, id="svg_script"), + pytest.param(SvgStyle, id="svg_style"), + ], +) +def test_restricted_content_element_isolates_stateful_bare_via_snapshot( + component_cls: type[Component], +) -> None: + """Restricted-content elements snapshot-wrap and never expose a Bare child. + + Asserts both the classification (the element opts into SNAPSHOT) and the + invariant (a stateful Bare child stays inside the snapshot rather than + being independently wrapped as a JSX component child of an element whose + content model rejects components). + """ + from reflex_base.components.memoize_helpers import is_snapshot_boundary + + instance = component_cls.create() + assert is_snapshot_boundary(instance), ( + f"{component_cls.__qualname__} should be classified as a snapshot boundary." + ) + assert get_memoization_strategy(instance) is MemoizationStrategy.SNAPSHOT, ( + f"{component_cls.__qualname__} should use SNAPSHOT strategy" + ) + + ctx, page_ctx = _compile_single_page(lambda: component_cls.create(STATE_VAR)) + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected exactly one snapshot wrapper for {component_cls.__qualname__}, " + f"got: {list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + output = page_ctx.output_code + assert output is not None + assert f"jsx({wrapper_tag}," in output, ( + f"Page output for {component_cls.__qualname__} must call the snapshot " + f"wrapper.\nPage output snippet: {output[:2000]}" + ) + assert "useTestState" not in output, ( + f"Stateful Bare child of <{getattr(component_cls, 'tag', '?')}> " + f"({component_cls.__qualname__}) leaked the state hook into the page; " + "the element's snapshot must capture it.\n" + f"Page output snippet: {output[:2000]}" + ) + assert "TestState" not in output, ( + f"State-context wiring for <{getattr(component_cls, 'tag', '?')}> " + f"({component_cls.__qualname__}) leaked into the page module.\n" + f"Page output snippet: {output[:2000]}" + ) + + +def _compile_memo_module_text(ctx: CompileContext) -> str: + """Compile the auto-memo definitions and return the concatenated JSX text. + + Args: + ctx: The compile context produced by ``_compile_single_page``. + + Returns: + The full memo module source code joined by newlines. + """ + from reflex.compiler.compiler import compile_memo_components + + memo_files, _imports = compile_memo_components( + components=(), + experimental_memos=tuple(ctx.auto_memo_components.values()), + ) + return "\n".join(code for _, code in memo_files) + + +def test_title_memo_body_renders_text_interpolation_not_bare_component() -> None: + """The title's memo body must interpolate the state Var as text. + + The body must contain a literal ``jsx("title", …)`` call carrying the + state-context wiring, and the page module must subscribe via the wrapper + rather than directly. The state hook/context lives in the memo body only. + """ + ctx, page_ctx = _compile_single_page(lambda: Title.create(STATE_VAR)) + memo_code = _compile_memo_module_text(ctx) + + assert 'jsx("title"' in memo_code, ( + f'Title snapshot body should contain a literal ``jsx("title", …)`` ' + f"call. Memo code:\n{memo_code[:2000]}" + ) + assert "useTestState" in memo_code, ( + "Title memo body should carry the stateful hook so the Bare child is " + f"interpolated inline, not lifted out.\nMemo code:\n{memo_code[:2000]}" + ) + + page_output = page_ctx.output_code + assert page_output is not None + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert f"jsx({wrapper_tag}," in page_output + assert "useTestState" not in page_output + + +def test_meta_memo_body_renders_void_element_inline() -> None: + """Meta's snapshot body should call ``jsx("meta", …)`` and own the state.""" + ctx, _page_ctx = _compile_single_page(lambda: Meta.create(STATE_VAR)) + memo_code = _compile_memo_module_text(ctx) + + assert 'jsx("meta"' in memo_code + assert "useTestState" in memo_code, ( + "Meta memo body should carry the stateful hook so the Bare child is " + f"interpolated inline, not lifted out.\nMemo code:\n{memo_code[:2000]}" + ) + + +def test_snapshot_boundary_memo_body_subscribes_state_in_body_not_page() -> None: + """State subscription wiring lives in the memo body, not in the page module. + + The whole point of memoization is to isolate state reads from the page. + This asserts that ``useContext(StateContexts…)`` (state subscription) + appears in the memo module and NOT in the page output, confirming the + state read landed inside the snapshot wrapper. + """ + from reflex_components_radix.primitives.accordion import AccordionTrigger + + trigger = AccordionTrigger.create( + rx.cond( + SpecialFormMemoState.flag, + rx.text("Hide"), + rx.text("Show"), + ) + ) + ctx, page_ctx = _compile_single_page(lambda: trigger) + memo_code = _compile_memo_module_text(ctx) + + assert "useContext(StateContexts" in memo_code, ( + "Snapshot wrapper should subscribe to state inside the memo body." + ) + page_output = page_ctx.output_code + assert page_output is not None + assert "useContext(StateContexts" not in page_output, ( + "State subscription should NOT appear in the page module — it must be " + "isolated inside the snapshot wrapper.\n" + f"Page output:\n{page_output[:2000]}" + ) + + +def test_nested_snapshot_boundaries_produce_one_outer_wrapper() -> None: + """A snapshot boundary inside another snapshot boundary produces ONE wrapper. + + The outer boundary's suppressor stack must absorb the inner boundary into + its own snapshot. Two nested wrappers would both duplicate the inner + component AND defeat the boundary's "I own my subtree" contract. + """ + inner = LeafComponent.create(Plain.create(STATE_VAR)) + outer = LeafComponent.create(inner) + ctx, _page_ctx = _compile_single_page(lambda: outer) + + assert len(ctx.memoize_wrappers) == 1, ( + f"Nested snapshot boundaries must collapse to one outer wrapper; got " + f"{list(ctx.memoize_wrappers)}" + ) + + memo_code = _compile_memo_module_text(ctx) + assert "jsx(LeafComponent" in memo_code, ( + "The outer wrapper's body must render the inner LeafComponent so the " + "suppressed inner boundary still appears under the outer snapshot.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + assert "jsx(Plain" in memo_code, ( + "The outer wrapper's body must render the innermost Plain component " + "and its Bare child so the stateful subtree lands inside the snapshot.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + + +def test_memoization_leaf_subclass_and_raw_recursive_false_behave_identically() -> None: + """Both ways to opt into recursive=False produce one snapshot wrapper. + + ``MemoizationLeaf`` subclasses and components that simply set + ``_memoization_mode = MemoizationMode(recursive=False)`` are handled + equivalently by the compiler. + """ + from reflex_base.components.component import MemoizationLeaf + + class LeafSubclass(MemoizationLeaf): + tag = "LeafSubclass" + library = "leaf-subclass-lib" + + leaf_subclass = LeafSubclass.create(Plain.create(STATE_VAR)) + raw_leaf = LeafComponent.create(Plain.create(STATE_VAR)) + + ctx_a, _ = _compile_single_page(lambda: leaf_subclass) + ctx_b, _ = _compile_single_page(lambda: raw_leaf) + + assert len(ctx_a.memoize_wrappers) == 1 + assert len(ctx_b.memoize_wrappers) == 1 + + +def test_snapshot_boundary_with_multiple_stateful_descendants_emits_one_wrapper() -> ( + None +): + """One boundary + many stateful descendants = one wrapper (not one per descendant). + + Without this invariant, a Radix primitive wrapping several stateful + children would balloon the page with one wrapper per child even though + the boundary already owns the subtree. + """ + boundary = LeafComponent.create( + Plain.create(STATE_VAR), + Plain.create(STATE_VAR), + WithProp.create(label=STATE_VAR), + ) + ctx, _page_ctx = _compile_single_page(lambda: boundary) + assert len(ctx.memoize_wrappers) == 1, ( + f"Multiple stateful descendants must share the boundary's wrapper; got " + f"{list(ctx.memoize_wrappers)}" + ) + + memo_code = _compile_memo_module_text(ctx) + assert memo_code.count("jsx(Plain") == 2, ( + "The boundary's snapshot body must render both Plain children inline.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + assert "jsx(WithProp" in memo_code, ( + "The boundary's snapshot body must render the WithProp child inline.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + + +def test_repeated_snapshot_boundary_subtrees_dedupe_to_one_definition() -> None: + """Two identical boundary subtrees collapse to one memo definition. + + Memo definitions are keyed on the rendered subtree shape, so two + identical boundaries should share a wrapper tag (even though they appear + twice on the page). + """ + ctx, page_ctx = _compile_single_page( + lambda: Fragment.create( + LeafComponent.create(Plain.create(STATE_VAR)), + LeafComponent.create(Plain.create(STATE_VAR)), + ) + ) + assert len(ctx.memoize_wrappers) == 1, ( + f"Identical boundary subtrees should share one wrapper; got " + f"{list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert (page_ctx.output_code or "").count(f"jsx({wrapper_tag},") == 2 + + +def test_snapshot_boundary_with_stateful_prop_and_descendant_emits_one_wrapper() -> ( + None +): + """A boundary with both stateful props and stateful descendants memoizes once.""" + from reflex_components_core.el.elements.metadata import Title + + title = Title.create( + STATE_VAR, # stateful child Bare + class_name=STATE_VAR.to(str), # stateful prop + ) + ctx, _page_ctx = _compile_single_page(lambda: title) + assert len(ctx.memoize_wrappers) == 1 + + +def test_disposition_never_overrides_snapshot_boundary_subtree_check() -> None: + """``MemoizationDisposition.NEVER`` wins even with a stateful subtree. + + Snapshot boundaries that explicitly opt out via NEVER must stay + unwrapped — useful for components that do their own memoization + elsewhere or shouldn't be memoized for correctness reasons. + """ + boundary = LeafComponent.create(Plain.create(STATE_VAR)) + object.__setattr__( + boundary, + "_memoization_mode", + dataclasses.replace( + boundary._memoization_mode, + disposition=MemoizationDisposition.NEVER, + ), + ) + assert not _should_memoize(boundary) + + +def test_static_subtree_inside_passthrough_no_memo_at_all() -> None: + """Sanity: a fully static page produces no memo wrappers. + + Guards against a regression where the new branch incorrectly fires for + components without state hooks. + """ + ctx, _page_ctx = _compile_single_page( + lambda: rx.box(rx.text("static"), rx.text("also static")) + ) + assert len(ctx.memoize_wrappers) == 0, ( + f"No state, no wrappers expected. Got: {list(ctx.memoize_wrappers)}" + ) + + +def test_void_element_with_only_stateful_prop_memoizes_via_snapshot() -> None: + """A void element with only a stateful prop still snapshot-wraps cleanly. + + Verifies that even without children, stateful props on void elements go + through the boundary's snapshot wrapper rather than degrading to a + passthrough that re-reads state on the page. + """ + from reflex_components_core.el.elements.media import Img + + img = Img.create(src=STATE_VAR.to(str)) + ctx, page_ctx = _compile_single_page(lambda: img) + assert len(ctx.memoize_wrappers) == 1 + output = page_ctx.output_code + assert output is not None + assert "useContext(StateContexts" not in output + + +@pytest.mark.parametrize( + "factory", + [ + pytest.param(lambda: Title.create("hello", id="t"), id="title_with_id"), + pytest.param(lambda: Img.create(src="/x.png", id="logo"), id="img_with_id"), + pytest.param(lambda: Br.create(id="br"), id="br_with_id"), + pytest.param(lambda: BaseInput.create(id="i"), id="input_with_id"), + pytest.param( + lambda: Meta.create(name="description", id="m"), id="meta_with_id" + ), + ], +) +def test_static_restricted_element_with_id_only_does_not_memoize( + factory: Callable[[], Component], +) -> None: + """Restricted-content elements with only an ``id`` ref stay unwrapped. + + The subtree scan subtracts the static-id ``useRef`` line from the + component's internal hooks so id-only elements do not flag as reactive. + Both ``MemoizationLeaf``-based elements and components that explicitly + set ``_memoization_mode = recursive=False`` go through this same path. + """ + component = factory() + ctx, _page_ctx = _compile_single_page(lambda: component) + assert len(ctx.memoize_wrappers) == 0, ( + f"Static restricted element with only an id ref should not memoize. " + f"Got wrappers: {list(ctx.memoize_wrappers)}" + ) + + +def test_static_restricted_element_no_id_no_children_does_not_memoize() -> None: + """Sanity: a fully static restricted element with no props/children stays unwrapped.""" + from reflex_components_core.el.elements.metadata import Title + + ctx, _page_ctx = _compile_single_page(lambda: Title.create("static-string")) + assert len(ctx.memoize_wrappers) == 0, ( + f"Static title should not memoize. Got: {list(ctx.memoize_wrappers)}" + ) + + +@pytest.mark.parametrize("global_ref", [True, False]) +def test_client_state_value_inside_snapshot_boundary_is_memoized( + global_ref: bool, +) -> None: + """Client-state Vars are reactive and must trigger boundary memoization. + + A ``client_state`` Var contributes its ``useState``/``useId`` hooks via + ``var_data.hooks`` without setting ``var_data.state``. The reactive-Var + walk must catch the hooks-only case so client-state-driven content + inside a snapshot boundary lands in the memo body. Both global and + page-local ``ClientStateVar`` Vars must drive the same wrapping. + """ + from reflex.experimental.client_state import ClientStateVar + + cs_var = ClientStateVar.create("titletest", default="hi", global_ref=global_ref) + title = Title.create(cs_var.value) + ctx, page_ctx = _compile_single_page(lambda: title) + assert len(ctx.memoize_wrappers) == 1, ( + "Client-state-driven title content must memoize. Got: " + f"{list(ctx.memoize_wrappers)}" + ) + page_output = page_ctx.output_code + assert page_output is not None + assert "useState" not in page_output, ( + "Client-state hooks should be inside the memo body, not the page.\n" + f"Page output snippet: {page_output[:2000]}" + ) + + +def test_hooks_only_var_data_descendant_inside_snapshot_boundary_is_memoized() -> None: + """Hook-bearing VarData without ``state`` still triggers snapshot memoization. + + Some frontend-only Vars contribute React hooks but do not carry a backend + state name. The snapshot-boundary subtree scan must catch those hooks-only + Vars so their hook lines land in the memo body instead of being suppressed + with the descendant. + """ + hook_var = Var(_js_expr="hookOnlyProbe")._replace( + merge_var_data=VarData(hooks={"const hookOnlyProbe = useHookOnly();": None}) + ) + child = Plain.create() + child.special_props = [hook_var] + boundary = LeafComponent.create(child) + + ctx, page_ctx = _compile_single_page(lambda: boundary) + memo_code = _compile_memo_module_text(ctx) + + assert len(ctx.memoize_wrappers) == 1, ( + f"Hooks-only descendant should produce one boundary wrapper, got: " + f"{list(ctx.memoize_wrappers)}" + ) + assert "useHookOnly" in memo_code, ( + "Hooks-only VarData should be emitted in the memo body.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + page_output = page_ctx.output_code + assert page_output is not None + assert "useHookOnly" not in page_output, ( + "Hooks-only VarData leaked into the page module.\n" + f"Page output snippet: {page_output[:2000]}" + ) + + +def test_added_hook_descendant_inside_snapshot_boundary_is_memoized() -> None: + """Hooks from ``add_hooks`` descendants trigger snapshot memoization. + + ``add_hooks`` output does not necessarily appear in any Var or event + trigger. Snapshot boundaries must still wrap so the walker skips the + descendant and the hook lands in the memo body, matching the signal used + by ``MemoizationLeaf.create``. + """ + + class HookOnlyChild(Component): + tag = "HookOnlyChild" + library = "hook-only-child-lib" + + def add_hooks(self) -> list[str]: + """Add a hook line for the regression test. + + Returns: + The hook lines this component contributes. + """ + return ["const hookOnlyChild = useHookOnlyChild();"] + + boundary = LeafComponent.create(HookOnlyChild.create()) + assert _should_memoize(boundary) + + ctx, page_ctx = _compile_single_page(lambda: boundary) + memo_code = _compile_memo_module_text(ctx) + + assert len(ctx.memoize_wrappers) == 1, ( + f"Added-hook descendant should produce one boundary wrapper, got: " + f"{list(ctx.memoize_wrappers)}" + ) + assert "useHookOnlyChild" in memo_code, ( + "add_hooks output should be emitted in the memo body.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + page_output = page_ctx.output_code + assert page_output is not None + assert "useHookOnlyChild" not in page_output, ( + "add_hooks output leaked into the page module.\n" + f"Page output snippet: {page_output[:2000]}" + ) + + +@pytest.mark.parametrize( + "factory", + [ + pytest.param(lambda: Meta.create(content=STATE_VAR), id="meta_content"), + pytest.param(lambda: Base.create(href=STATE_VAR), id="base_href"), + pytest.param(lambda: Link.create(href=STATE_VAR), id="link_href"), + pytest.param(lambda: Script.create(src=STATE_VAR), id="script_src"), + pytest.param(lambda: BaseInput.create(value=STATE_VAR), id="input_value"), + pytest.param(lambda: Textarea.create(value=STATE_VAR), id="textarea_value"), + pytest.param(lambda: SvgStyle.create(media=STATE_VAR), id="svg_style_media"), + ], +) +def test_restricted_content_element_with_stateful_attribute_uses_snapshot( + factory: Callable[[], Component], +) -> None: + """Stateful attrs on restricted-content elements are isolated in snapshots.""" + ctx, page_ctx = _compile_single_page(factory) + memo_code = _compile_memo_module_text(ctx) + + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected one snapshot wrapper for a stateful restricted attr, got: " + f"{list(ctx.memoize_wrappers)}" + ) + page_output = page_ctx.output_code + assert page_output is not None + assert "useTestState" not in page_output, ( + "Reactive hook marker for restricted attr should not leak to page output.\n" + f"Page output snippet: {page_output[:2000]}" + ) + assert "useTestState" in memo_code, ( + "Reactive hook marker for restricted attr should live in the memo body.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + + +def _real_recursive_false_factories() -> list: + from reflex_components_radix.primitives.dialog import DialogTrigger + from reflex_components_radix.primitives.drawer import DrawerTrigger + from reflex_components_radix.themes.components.popover import PopoverTrigger + from reflex_components_radix.themes.components.tabs import TabsTrigger + from reflex_components_radix.themes.components.tooltip import Tooltip + + return [ + pytest.param( + lambda: DialogTrigger.create(rx.text(STATE_VAR)), + id="primitive_dialog_trigger", + ), + pytest.param( + lambda: DrawerTrigger.create(rx.text(STATE_VAR)), + id="primitive_drawer_trigger", + ), + pytest.param( + lambda: PopoverTrigger.create(rx.text(STATE_VAR)), + id="themes_popover_trigger", + ), + pytest.param( + lambda: TabsTrigger.create(rx.text(STATE_VAR), value="tab-1"), + id="themes_tabs_trigger", + ), + pytest.param( + lambda: Tooltip.create(rx.text(STATE_VAR), content="tip"), + id="themes_tooltip", + ), + ] + + +@pytest.mark.parametrize("factory", _real_recursive_false_factories()) +def test_real_recursive_false_components_with_stateful_descendants_snapshot_wrap( + factory: Callable[[], Component], +) -> None: + """Several real ``recursive=False`` components share the boundary behavior.""" + component = factory() + ctx, page_ctx = _compile_single_page(lambda: component) + + assert len(ctx.memoize_wrappers) == 1, ( + f"Expected one wrapper for {type(component).__qualname__}, got: " + f"{list(ctx.memoize_wrappers)}" + ) + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert type(component).__name__.lower() in wrapper_tag.lower(), ( + f"Wrapper {wrapper_tag!r} should be derived from {type(component).__name__}." + ) + page_output = page_ctx.output_code + assert page_output is not None + assert "useContext(StateContexts" not in page_output, ( + "Stateful descendant under a real snapshot boundary leaked to page output.\n" + f"Page output snippet: {page_output[:2000]}" + ) + + +def test_restricted_content_element_with_id_and_stateful_child_still_memoizes() -> None: + """Static ref filtering must not suppress real stateful content.""" + from reflex_components_core.el.elements.metadata import Title + + title = Title.create(STATE_VAR, id="stateful-title") + ctx, page_ctx = _compile_single_page(lambda: title) + memo_code = _compile_memo_module_text(ctx) + + assert len(ctx.memoize_wrappers) == 1, ( + f"Stateful title with id should still memoize, got: {list(ctx.memoize_wrappers)}" + ) + page_output = page_ctx.output_code + assert page_output is not None + assert "ref_stateful_title" not in page_output, ( + "The title ref should move with the snapshot body, not stay on the page.\n" + f"Page output snippet: {page_output[:2000]}" + ) + assert "ref_stateful_title" in memo_code, ( + "The title ref should be emitted inside the snapshot memo body.\n" + f"Memo code snippet: {memo_code[:2000]}" + ) + + +def test_each_memo_wrapper_emits_one_component_module_file() -> None: + """Every wrapper tag corresponds to exactly one ``components/{tag}.jsx`` file. + + Locks the per-wrapper file invariant: ``compile_memo_components`` must + emit one module per wrapper (plus the shared index), so that React can + code-split per wrapper. A wrapper without a file (or a file without a + wrapper) would mean broken imports at runtime. + """ + from reflex.compiler.compiler import compile_memo_components + + ctx, _page_ctx = _compile_single_page( + lambda: Fragment.create( + Plain.create(STATE_VAR), + WithProp.create(label=STATE_VAR), + LeafComponent.create(Plain.create(STATE_VAR)), + ) + ) + memo_files, _imports = compile_memo_components( + components=(), + experimental_memos=tuple(ctx.auto_memo_components.values()), + ) + component_module_names = { + Path(path).name + for path, _ in memo_files + if Path(path).parent.name == "components" + } + expected = {f"{tag}.jsx" for tag in ctx.memoize_wrappers} + assert component_module_names == expected, ( + f"Per-wrapper file invariant broken. wrappers={sorted(ctx.memoize_wrappers)} " + f"files={sorted(component_module_names)}" + ) + assert len(ctx.memoize_wrappers) == 3, ( + "Test should exercise the multi-wrapper case: one passthrough wrapper " + "for Plain, one for WithProp, and one snapshot wrapper for the " + f"LeafComponent boundary. Got: {sorted(ctx.memoize_wrappers)}" + ) diff --git a/tests/units/compiler/test_plugins.py b/tests/units/compiler/test_plugins.py new file mode 100644 index 00000000000..26eb1f39c99 --- /dev/null +++ b/tests/units/compiler/test_plugins.py @@ -0,0 +1,1142 @@ +# ruff: noqa: D101, D102 + +import dataclasses +from collections.abc import Callable +from typing import Any + +import pytest +from reflex_base.components.component import ( + BaseComponent, + Component, + ComponentStyle, + field, +) +from reflex_base.plugins import ( + BaseContext, + CompileContext, + CompilerHooks, + ComponentAndChildren, + PageContext, + PageDefinition, + Plugin, +) +from reflex_base.plugins.base import HookOrder +from reflex_base.utils import format as format_utils +from reflex_base.utils.imports import ImportVar, collapse_imports, merge_imports +from reflex_base.vars import VarData +from reflex_base.vars.base import LiteralVar, Var +from reflex_components_core.base.fragment import Fragment + +from reflex.app import UnevaluatedPage +from reflex.compiler import compiler +from reflex.compiler.plugins import ( + ApplyStylePlugin, + DefaultCollectorPlugin, + DefaultPagePlugin, + default_page_plugins, +) + + +@dataclasses.dataclass(slots=True) +class FakePage: + route: str + component: Callable[[], Component] + title: Var | str | None = None + description: Var | str | None = None + image: str = "" + meta: tuple[dict[str, Any], ...] = () + + +class WrapperComponent(Component): + tag = "WrapperComponent" + library = "wrapper-lib" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(20, "NestedWrap"): Fragment.create()} + + +class RootComponent(Component): + tag = "RootComponent" + library = "root-lib" + + slot: Component | None = field(default=None) + + def add_style(self) -> dict[str, Any] | None: + return {"display": "flex"} + + def add_custom_code(self) -> list[str]: + return ["const rootAddedCode = 1;"] + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(10, "Wrap"): WrapperComponent.create()} + + +class ChildComponent(Component): + tag = "ChildComponent" + library = "child-lib" + + def add_style(self) -> dict[str, Any] | None: + return {"align_items": "center"} + + def add_custom_code(self) -> list[str]: + return ["const childAddedCode = 1;"] + + def _get_custom_code(self) -> str | None: + return "const childCustomCode = 1;" + + def _get_hooks(self) -> str | None: + return "const childHook = useChildHook();" + + +class PropComponent(Component): + tag = "PropComponent" + library = "prop-lib" + + def add_custom_code(self) -> list[str]: + return ["const propAddedCode = 1;"] + + def _get_custom_code(self) -> str | None: + return "const propCustomCode = 1;" + + def _get_dynamic_imports(self) -> str | None: + return "dynamic(() => import('prop-lib'))" + + def _get_hooks(self) -> str | None: + return "const propHook = usePropHook();" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(15, "PropWrap"): Fragment.create()} + + +class SharedLibraryComponent(Component): + tag = "SharedLibraryComponent" + library = "react-moment" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(25, "SharedLibraryWrap"): Fragment.create()} + + +class InlineStatefulComponent(Component): + tag = "InlineStatefulComponent" + library = "inline-lib" + + +class ReplacementComponent(Component): + tag = "ReplacementComponent" + library = "replacement-lib" + + def _get_custom_code(self) -> str | None: + return "const replacementCustomCode = 1;" + + +class StubPlugin(Plugin): + pass + + +SHARED_STATEFUL_VAR = LiteralVar.create("shared")._replace( + merge_var_data=VarData( + hooks={"useSharedStatefulValue": None}, + state="SharedState", + ) +) + +INLINE_STATEFUL_VAR = LiteralVar.create("inline")._replace( + merge_var_data=VarData( + hooks={"useInlineStatefulValue": None}, + state="InlineState", + ) +) + + +def create_component_tree() -> RootComponent: + return RootComponent.create( + ChildComponent.create(id="child-id", style={"color": "red"}), + slot=PropComponent.create(id="prop-id", style={"opacity": "0.5"}), + style={"margin": "0"}, + ) + + +def create_shared_stateful_component() -> SharedLibraryComponent: + return SharedLibraryComponent.create(SHARED_STATEFUL_VAR) + + +def create_inline_stateful_component() -> InlineStatefulComponent: + return InlineStatefulComponent.create(INLINE_STATEFUL_VAR) + + +def page_style() -> ComponentStyle: + return { + RootComponent: {"padding": "1rem"}, + ChildComponent: {"font_size": "12px"}, + PropComponent: {"border": "1px solid green"}, + } + + +def normalize_style(component: BaseComponent) -> dict[str, str]: + assert isinstance(component, Component) + return {key: str(value) for key, value in component.style.items()} + + +def create_compile_context(hooks: CompilerHooks | None = None) -> CompileContext: + return CompileContext(pages=[], hooks=hooks or CompilerHooks()) + + +def collect_page_context( + component: BaseComponent, + *, + plugins: tuple[Any, ...], +) -> PageContext: + page_ctx = PageContext( + name="page", + route="/page", + root_component=component, + ) + hooks = CompilerHooks(plugins=plugins) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + page_ctx.root_component = hooks.compile_component( + page_ctx.root_component, + page_context=page_ctx, + compile_context=compile_ctx, + ) + hooks.compile_page(page_ctx, compile_context=compile_ctx) + + return page_ctx + + +def test_eval_page_uses_first_non_none_result() -> None: + calls: list[str] = [] + page = FakePage(route="/demo", component=lambda: Fragment.create()) + + class NoMatchPlugin(StubPlugin): + def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> None: + del page_fn, page, kwargs + calls.append("no-match") + + class MatchPlugin(StubPlugin): + def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext: + del kwargs + calls.append("match") + return PageContext( + name="page", + route=page.route, + root_component=page_fn(), + ) + + class UnreachablePlugin(StubPlugin): + def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext: + del page_fn, page, kwargs + calls.append("unreachable") + msg = "eval_page should stop at the first page context" + raise AssertionError(msg) + + hooks = CompilerHooks(plugins=(NoMatchPlugin(), MatchPlugin(), UnreachablePlugin())) + + page_ctx = hooks.eval_page(page.component, page=page, compile_context=None) + + assert page_ctx is not None + assert page_ctx.route == "/demo" + assert calls == ["no-match", "match"] + + +def test_compile_page_runs_plugins_in_registration_order() -> None: + calls: list[str] = [] + page_ctx = PageContext( + name="page", + route="/ordered", + root_component=Fragment.create(), + ) + + class FirstPlugin(StubPlugin): + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + del page_ctx, kwargs + calls.append("first") + + class SecondPlugin(StubPlugin): + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + del page_ctx, kwargs + calls.append("second") + + hooks = CompilerHooks(plugins=(FirstPlugin(), SecondPlugin())) + hooks.compile_page(page_ctx, compile_context=None) + + assert calls == ["first", "second"] + + +def test_component_hook_resolution_caches_only_real_overrides() -> None: + class EnterPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del comp, page_context, compile_context, in_prop_tree + + class LeavePlugin(StubPlugin): + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del ( + comp, + children, + page_context, + compile_context, + in_prop_tree, + ) + + hooks = CompilerHooks(plugins=(Plugin(), EnterPlugin(), LeavePlugin())) + + assert len(hooks._enter_component_hook_binders) == 1 + assert len(hooks._leave_component_hook_binders) == 1 + + +def test_enter_component_skips_inherited_base_plugin_hook( + monkeypatch: pytest.MonkeyPatch, +) -> None: + visited: list[str] = [] + root = RootComponent.create() + + def fail_enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del self, comp, page_context, compile_context, in_prop_tree + msg = "Inherited Plugin.enter_component hook should be skipped." + raise AssertionError(msg) + + monkeypatch.setattr(Plugin, "enter_component", fail_enter_component) + + class RealPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del page_context, compile_context, in_prop_tree + visited.append(type(comp).__name__) + + hooks = CompilerHooks(plugins=(Plugin(), RealPlugin())) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert visited == ["RootComponent"] + + +def test_enter_component_skips_inherited_protocol_hook( + monkeypatch: pytest.MonkeyPatch, +) -> None: + visited: list[str] = [] + root = RootComponent.create() + + def fail_enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del self, comp, page_context, compile_context, in_prop_tree + msg = "Inherited Plugin.enter_component hook should be skipped." + raise AssertionError(msg) + + monkeypatch.setattr(Plugin, "enter_component", fail_enter_component) + + class ProtocolOnlyPlugin(Plugin): + pass + + class RealPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del page_context, compile_context, in_prop_tree + visited.append(type(comp).__name__) + + hooks = CompilerHooks(plugins=(ProtocolOnlyPlugin(), RealPlugin())) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert visited == ["RootComponent"] + + +def test_compile_component_orders_enter_and_leave_by_plugin() -> None: + events: list[str] = [] + root = RootComponent.create() + + class FirstPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del comp, page_context, compile_context, in_prop_tree + events.append("first:enter") + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del ( + comp, + children, + page_context, + compile_context, + in_prop_tree, + ) + events.append("first:leave") + + class SecondPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del comp, page_context, compile_context, in_prop_tree + events.append("second:enter") + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del ( + comp, + children, + page_context, + compile_context, + in_prop_tree, + ) + events.append("second:leave") + + hooks = CompilerHooks(plugins=(FirstPlugin(), SecondPlugin())) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + compiled_root = hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert compiled_root is root + assert events == [ + "first:enter", + "second:enter", + "second:leave", + "first:leave", + ] + + +def test_compile_component_traverses_children_before_prop_components() -> None: + visited: list[str] = [] + root = RootComponent.create( + ChildComponent.create(), + slot=PropComponent.create(), + ) + + class VisitPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del page_context, compile_context, in_prop_tree + if isinstance(comp, Component): + visited.append(comp.tag or type(comp).__name__) + + hooks = CompilerHooks(plugins=(VisitPlugin(),)) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert visited == ["RootComponent", "ChildComponent", "PropComponent"] + + +def test_enter_and_leave_replacements_match_generator_style_behavior() -> None: + child = ChildComponent.create(id="original") + root = RootComponent.create(child) + + class ReplacePlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> BaseComponent | ComponentAndChildren | None: + del page_context, compile_context + if isinstance(comp, RootComponent) and not in_prop_tree: + replacement_child = ChildComponent.create(id="replacement") + return comp, (replacement_child,) + return None + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> BaseComponent | ComponentAndChildren | None: + del page_context, compile_context, in_prop_tree + if isinstance(comp, RootComponent): + return Fragment.create(comp), children + return None + + hooks = CompilerHooks(plugins=(ReplacePlugin(),)) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + compiled_root = hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert isinstance(compiled_root, Fragment) + assert len(compiled_root.children) == 1 + replacement_child = compiled_root.children[0] + assert isinstance(replacement_child, ChildComponent) + assert str(replacement_child.id) == "replacement" + + +def test_context_lifecycle_and_cleanup() -> None: + compile_ctx = CompileContext(pages=[], hooks=CompilerHooks()) + page_ctx = PageContext( + name="page", + route="/ctx", + root_component=Fragment.create(), + ) + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + with pytest.raises( + RuntimeError, match="must be entered with 'with' or 'async with'" + ): + compile_ctx.ensure_context_attached() + + with compile_ctx: + assert CompileContext.get() is compile_ctx + with pytest.raises(RuntimeError, match="No active PageContext"): + PageContext.get() + with page_ctx: + assert CompileContext.get() is compile_ctx + assert PageContext.get() is page_ctx + page_ctx.ensure_context_attached() + with pytest.raises(RuntimeError, match="No active PageContext"): + PageContext.get() + assert CompileContext.get() is compile_ctx + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + + with pytest.raises(ValueError, match="boom"), compile_ctx: + msg = "boom" + raise ValueError(msg) + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + + +def test_page_context_default_factories_are_isolated() -> None: + page_ctx_a = PageContext( + name="a", + route="/a", + root_component=Fragment.create(), + ) + page_ctx_b = PageContext( + name="b", + route="/b", + root_component=Fragment.create(), + ) + + page_ctx_a.imports.append({"lib-a": [ImportVar(tag="ThingA")]}) + page_ctx_a.module_code["const a = 1;"] = None + page_ctx_a.hooks["hookA"] = None + page_ctx_a.dynamic_imports.add("dynamic-a") + page_ctx_a.refs["refA"] = None + page_ctx_a.app_wrap_components[1, "WrapA"] = Fragment.create() + + assert page_ctx_b.imports == [] + assert page_ctx_b.module_code == {} + assert page_ctx_b.hooks == {} + assert page_ctx_b.dynamic_imports == set() + assert page_ctx_b.refs == {} + assert page_ctx_b.app_wrap_components == {} + + +def test_page_context_helpers_preserve_accumulated_values() -> None: + page_ctx = PageContext( + name="page", + route="/page", + root_component=Fragment.create(), + ) + page_ctx.imports.extend([ + {"lib-a": [ImportVar(tag="ThingA")]}, + {"lib-a": [ImportVar(tag="ThingB")], "lib-b": [ImportVar(tag="ThingC")]}, + ]) + page_ctx.module_code["const first = 1;"] = None + page_ctx.module_code["const second = 2;"] = None + + assert page_ctx.merged_imports() == merge_imports(*page_ctx.imports) + assert page_ctx.merged_imports(collapse=True) == collapse_imports( + merge_imports(*page_ctx.imports) + ) + assert list(page_ctx.custom_code_dict()) == [ + "const first = 1;", + "const second = 2;", + ] + + +def test_base_context_subclasses_initialize_distinct_context_vars() -> None: + class DynamicContext(BaseContext): + pass + + class AnotherDynamicContext(BaseContext): + pass + + assert DynamicContext.__context_var__ is not AnotherDynamicContext.__context_var__ + + +def test_apply_style_plugin_matches_legacy_style_behavior() -> None: + component = create_component_tree() + legacy_component = create_component_tree() + + legacy_component._add_style_recursive(page_style()) + + original_style_snapshot = normalize_style(component) + original_child_style_snapshot = normalize_style(component.children[0]) + + hooks = CompilerHooks(plugins=(ApplyStylePlugin(style=page_style()),)) + page_ctx = PageContext(name="page", route="/page", root_component=component) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + compiled = hooks.compile_component( + component, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert normalize_style(compiled) == normalize_style(legacy_component) + assert normalize_style(compiled.children[0]) == normalize_style( + legacy_component.children[0] + ) + assert isinstance(compiled, type(legacy_component)) + assert compiled.slot is not None + assert legacy_component.slot is not None + assert normalize_style(compiled.slot) == normalize_style(legacy_component.slot) + + assert normalize_style(component) == original_style_snapshot + assert normalize_style(component.children[0]) == original_child_style_snapshot + + +def test_default_collector_matches_legacy_collectors() -> None: + component = create_component_tree() + assert "prop-lib" in component._get_all_imports(collapse=True) + + page_ctx = collect_page_context( + component, + plugins=(DefaultCollectorPlugin(),), + ) + + assert page_ctx.imports == [component._get_all_imports(collapse=True)] + assert "prop-lib" in page_ctx.frontend_imports + assert page_ctx.hooks == component._get_all_hooks() + assert "usePropHook" not in "".join(page_ctx.hooks) + assert page_ctx.module_code == component._get_all_custom_code() + assert page_ctx.dynamic_imports == component._get_all_dynamic_imports() + assert page_ctx.refs == component._get_all_refs() + assert page_ctx.refs == { + format_utils.format_ref("child-id"): None, + format_utils.format_ref("prop-id"): None, + } + assert ( + page_ctx.app_wrap_components.keys() + == component._get_all_app_wrap_components().keys() + ) + + +def test_default_collector_collects_nested_prop_tree_custom_code_without_recursion() -> ( + None +): + component = RootComponent.create( + slot=PropComponent.create( + ChildComponent.create(), + ) + ) + + page_ctx = collect_page_context( + component, + plugins=(DefaultCollectorPlugin(),), + ) + + assert page_ctx.module_code == component._get_all_custom_code() + assert "const propCustomCode = 1;" in page_ctx.module_code + assert "const childCustomCode = 1;" in page_ctx.module_code + + +def test_default_page_plugins_are_minimal_and_ordered() -> None: + from reflex.compiler.plugins.memoize import MemoizeStatefulPlugin + + plugins = default_page_plugins(style=page_style()) + + assert len(plugins) == 4 + assert isinstance(plugins[0], DefaultPagePlugin) + assert isinstance(plugins[1], ApplyStylePlugin) + assert isinstance(plugins[2], DefaultCollectorPlugin) + assert isinstance(plugins[3], MemoizeStatefulPlugin) + + +def test_compile_context_collects_artifacts_from_leave_replacement_plugins() -> None: + page = FakePage(route="/replacement", component=create_component_tree) + + class ReplaceRootPlugin(StubPlugin): + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> BaseComponent | None: + del page_context, compile_context, in_prop_tree + if isinstance(comp, RootComponent): + return ReplacementComponent.create(*children) + return None + + compile_ctx = CompileContext( + pages=[page], + hooks=CompilerHooks( + plugins=default_page_plugins(plugins=(ReplaceRootPlugin(),)) + ), + ) + + with compile_ctx: + compile_ctx.compile() + + page_ctx = compile_ctx.compiled_pages["/replacement"] + assert ( + page_ctx.root_component.render()["children"][0]["name"] + == "ReplacementComponent" + ) + assert "replacement-lib" in page_ctx.frontend_imports + assert "root-lib" not in page_ctx.frontend_imports + assert "const replacementCustomCode = 1;" in page_ctx.module_code + assert "const rootAddedCode = 1;" not in page_ctx.module_code + assert ("import {" + 'ReplacementComponent} from "replacement-lib"') in ( + page_ctx.output_code or "" + ) + assert ("import {" + 'RootComponent} from "root-lib"') not in ( + page_ctx.output_code or "" + ) + + +def test_leave_component_order_dispatches_pre_normal_post() -> None: + calls: list[str] = [] + + class LabelledLeavePlugin(StubPlugin): + label: str = "" + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del children, page_context, compile_context, in_prop_tree + if isinstance(comp, RootComponent): + calls.append(self.label) + + class PrePlugin(LabelledLeavePlugin): + _compiler_leave_component_order = HookOrder.PRE + label = "pre" + + class NormalPlugin(LabelledLeavePlugin): + label = "normal" + + class PostPlugin(LabelledLeavePlugin): + _compiler_leave_component_order = HookOrder.POST + label = "post" + + component = create_component_tree() + hooks = CompilerHooks(plugins=(PostPlugin(), NormalPlugin(), PrePlugin())) + page_ctx = PageContext(name="page", route="/page", root_component=component) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + component, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert calls == ["pre", "normal", "post"] + + +def test_compile_context_compiles_pages_and_matches_legacy_output() -> None: + page = FakePage(route="/demo", component=create_component_tree) + compile_ctx = CompileContext( + pages=[page], + hooks=CompilerHooks(plugins=default_page_plugins(style=page_style())), + ) + + with compile_ctx: + compiled_pages = compile_ctx.compile() + + assert compiled_pages is compile_ctx.compiled_pages + assert list(compiled_pages) == ["/demo"] + + page_ctx = compiled_pages["/demo"] + assert isinstance(page_ctx.root_component, Component) + assert page_ctx.name == "create_component_tree" + assert page_ctx.route == "/demo" + assert "prop-lib" in page_ctx.root_component._get_all_imports(collapse=True) + assert page_ctx.frontend_imports == page_ctx.merged_imports(collapse=True) + assert "prop-lib" in page_ctx.frontend_imports + compile_ctx_imports = collapse_imports(compile_ctx.all_imports) + for lib, fields in page_ctx.frontend_imports.items(): + assert lib in compile_ctx_imports + assert set(compile_ctx_imports[lib]) >= set(fields) + assert page_ctx.output_path is not None + assert page_ctx.output_code is not None + # `collapse_imports` uses `list(set(...))`, so the per-library ImportVar + # lists don't have a stable order across processes. Compare as sets. + [actual_imports] = page_ctx.imports + expected_imports = page_ctx.root_component._get_all_imports(collapse=True) + assert actual_imports.keys() == expected_imports.keys() + for lib, actual_vars in actual_imports.items(): + assert set(actual_vars) == set(expected_imports[lib]) + assert page_ctx.hooks == page_ctx.root_component._get_all_hooks() + assert page_ctx.module_code == page_ctx.root_component._get_all_custom_code() + assert ( + page_ctx.dynamic_imports == page_ctx.root_component._get_all_dynamic_imports() + ) + assert page_ctx.refs == page_ctx.root_component._get_all_refs() + assert ( + page_ctx.app_wrap_components.keys() + == page_ctx.root_component._get_all_app_wrap_components().keys() + ) + + legacy_component = compiler.compile_unevaluated_page( + page.route, + UnevaluatedPage( + component=page.component, + route=page.route, + title=page.title, + description=page.description, + image=page.image, + on_load=None, + meta=page.meta, + context={}, + ), + page_style(), + None, + ) + legacy_output = compiler.compile_page(page.route, legacy_component)[1] + + # The two compile paths produce the same content but the plugin pipeline + # inserts imports and hoistable const declarations in post-order (leaf + # first) while legacy inserts them in pre-order. Neither order matters to + # the JS engine — imports are hoisted, and the consts don't reference one + # another. Compare the preamble as a set of lines, and the component body + # (where hook order and JSX are meaningful) byte-for-byte. + preamble_marker = "export default function Component" + + def preamble_lines(output: str) -> set[str]: + preamble, _, _ = output.partition(preamble_marker) + return set(preamble.splitlines()) + + def component_body(output: str) -> str: + _, sep, body = output.partition(preamble_marker) + return sep + body + + assert preamble_lines(page_ctx.output_code) == preamble_lines(legacy_output) + assert component_body(page_ctx.output_code) == component_body(legacy_output) + + +def test_default_page_plugin_handles_var_backed_title_like_legacy_compiler() -> None: + page = UnevaluatedPage( + component=lambda: Fragment.create(), + route="/var-title", + title=Var(_js_expr="pageTitle", _var_type=str), + description=None, + image="", + on_load=None, + meta=(), + context={}, + ) + hooks = CompilerHooks(plugins=(DefaultPagePlugin(),)) + compile_ctx = create_compile_context(hooks) + + with compile_ctx: + page_ctx = hooks.eval_page( + page.component, + page=page, + compile_context=compile_ctx, + ) + + assert page_ctx is not None + + legacy_component = compiler.compile_unevaluated_page( + page.route, + page, + None, + None, + ) + assert page_ctx.root_component.render() == legacy_component.render() + + +def test_compile_context_rejects_duplicate_routes() -> None: + pages = [ + FakePage(route="/duplicate", component=lambda: Fragment.create()), + FakePage(route="/duplicate", component=lambda: Fragment.create()), + ] + compile_ctx = CompileContext( + pages=pages, + hooks=CompilerHooks(plugins=(DefaultPagePlugin(),)), + ) + + with ( + compile_ctx, + pytest.raises( + RuntimeError, + match="Duplicate compiled page route", + ), + ): + compile_ctx.compile() + + +def test_compile_context_requires_attached_context() -> None: + compile_ctx = CompileContext( + pages=[], + hooks=CompilerHooks(), + ) + + with pytest.raises( + RuntimeError, match="must be entered with 'with' or 'async with'" + ): + compile_ctx.compile() + + +def test_compile_context_memoize_wrappers_registers_shared_subtree_tag() -> None: + """Shared memoizable subtree across pages registers a single wrapper tag.""" + pages = [ + FakePage(route="/a", component=create_shared_stateful_component), + FakePage(route="/b", component=create_shared_stateful_component), + ] + compile_ctx = CompileContext( + pages=pages, + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + + with compile_ctx: + compile_ctx.compile() + + # The wrapped library import still reaches the compile-context level. + assert "react-moment" in compile_ctx.all_imports + assert (25, "SharedLibraryWrap") in compile_ctx.app_wrap_components + # Both pages share the same subtree hash, so exactly one wrapper tag is registered. + assert len(compile_ctx.memoize_wrappers) == 1 + wrapper_tag = next(iter(compile_ctx.memoize_wrappers)) + assert list(compile_ctx.auto_memo_components) == [wrapper_tag] + # Each page imports the generated experimental memo component. + page_a_code = compile_ctx.compiled_pages["/a"].output_code or "" + assert ( + f'import {{{wrapper_tag}}} from "$/utils/components/{wrapper_tag}"' + in page_a_code + ) + assert f"jsx({wrapper_tag}," in page_a_code + assert f"const {wrapper_tag} = memo" not in page_a_code + # The removed shared-stateful-components path should not appear anywhere. + assert "$/utils/stateful_components" not in page_a_code + + +def test_compile_context_resets_memoize_wrappers_between_runs() -> None: + """``CompileContext.memoize_wrappers`` is cleared on each compile run.""" + ctx = CompileContext( + pages=[FakePage(route="/a", component=create_shared_stateful_component)], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx: + ctx.compile() + first_tags = set(ctx.memoize_wrappers) + first_defs = set(ctx.auto_memo_components) + assert first_tags # memoize wrapper was registered + assert first_defs == first_tags + + # Re-compile with a different page set → wrappers reset, not accumulated. + ctx2 = CompileContext( + pages=[FakePage(route="/c", component=create_shared_stateful_component)], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx2: + ctx2.compile() + + # Same shared component → same tag, not a union across runs. + assert set(ctx2.memoize_wrappers) == first_tags + assert set(ctx2.auto_memo_components) == first_tags + page_ctx = ctx2.compiled_pages["/c"] + assert "react-moment" in page_ctx.frontend_imports + assert "$/utils/stateful_components" not in (page_ctx.output_code or "") + + +def test_compile_context_applies_style_before_inline_stateful_render() -> None: + compile_ctx = CompileContext( + pages=[ + FakePage( + route="/styled", + component=create_inline_stateful_component, + ) + ], + hooks=CompilerHooks( + plugins=default_page_plugins( + style={InlineStatefulComponent: {"color": "red"}} + ) + ), + ) + + with compile_ctx: + compile_ctx.compile() + + assert '["color"] : "red"' in ( + compile_ctx.compiled_pages["/styled"].output_code or "" + ) + + +def test_compile_context_applies_style_before_shared_stateful_render() -> None: + compile_ctx = CompileContext( + pages=[ + FakePage(route="/a", component=create_shared_stateful_component), + FakePage(route="/b", component=create_shared_stateful_component), + ], + hooks=CompilerHooks( + plugins=default_page_plugins( + style={SharedLibraryComponent: {"color": "red"}} + ) + ), + ) + + with compile_ctx: + compile_ctx.compile() + + assert '["color"] : "red"' in (compile_ctx.compiled_pages["/a"].output_code or "") + assert '["color"] : "red"' in (compile_ctx.compiled_pages["/b"].output_code or "") diff --git a/tests/units/components/core/test_cond.py b/tests/units/components/core/test_cond.py index bf10dc968bd..a3c45875417 100644 --- a/tests/units/components/core/test_cond.py +++ b/tests/units/components/core/test_cond.py @@ -123,6 +123,19 @@ def test_cond_no_else(): cond(True, "hello") # pyright: ignore [reportArgumentType] +def test_cond_render_missing_false_child_defaults_to_fragment() -> None: + """Test Cond renders missing false branch as empty Fragment.""" + comp = Cond._create( + cond=LiteralVar.create(True).bool(), + children=[Fragment.create(Text.create("hello"))], + ) + + rendered = comp.render() + + assert rendered["true_value"] == Fragment.create(Text.create("hello")).render() + assert rendered["false_value"] == Fragment.create().render() + + def test_cond_computed_var(): """Test if cond works with computed vars.""" diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index de397442311..9ca3495a2a2 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -7,8 +7,8 @@ CUSTOM_COMPONENTS, Component, CustomComponent, - StatefulComponent, custom_component, + field, ) from reflex_base.constants import EventTriggers from reflex_base.constants.state import FIELD_MARKER @@ -522,30 +522,25 @@ def test_get_imports(component1, component2): } -def test_get_all_imports_includes_components_in_props(): - """Test that _get_all_imports collects imports from components in props.""" +def test_get_imports_includes_components_in_props(): + """Test that component-valued props contribute their imports.""" - class InnerComponent(Component): - """A component that requires a specific import.""" + class PropComponent(Component): + tag = "PropComponent" + library = "prop-lib" - def _get_imports(self) -> ParsedImportDict: - return {"some-library": [ImportVar(tag="SomeTag")]} - - class OuterComponent(Component): - """A component with a component-typed prop.""" + class ParentComponent(Component): + tag = "ParentComponent" + library = "parent-lib" - fallback: Component | None = None + slot: Component | None = field(default=None) - def _get_imports(self) -> ParsedImportDict: - return {"outer-library": [ImportVar(tag="OuterTag")]} + imports_ = ParentComponent.create(slot=PropComponent.create())._get_all_imports() - inner = InnerComponent.create() - outer = OuterComponent.create(fallback=inner) - all_imports = outer._get_all_imports() - assert "some-library" in all_imports, ( - "_get_all_imports() should collect imports from components in props" - ) - assert "outer-library" in all_imports + assert imports_ == parse_imports({ + "parent-lib": ["ParentComponent"], + "prop-lib": ["PropComponent"], + }) def test_get_custom_code(component1: Component, component2: Component): @@ -1191,47 +1186,6 @@ def test_format_component(component, rendered): assert str(component) == rendered -def test_stateful_component(test_state: type[TestState]): - """Test that a stateful component is created correctly. - - Args: - test_state: A test state. - """ - text_component = rx.text(test_state.num) - stateful_component = StatefulComponent.compile_from(text_component) - assert isinstance(stateful_component, StatefulComponent) - assert stateful_component.tag is not None - assert stateful_component.tag.startswith("Text_") - assert stateful_component.references == 1 - sc2 = StatefulComponent.compile_from(rx.text(test_state.num)) - assert isinstance(sc2, StatefulComponent) - assert stateful_component.references == 2 - assert sc2.references == 2 - - -def test_stateful_component_memoize_event_trigger(test_state: type[TestState]): - """Test that a stateful component is created correctly with events. - - Args: - test_state: A test state. - """ - button_component = rx.button("Click me", on_blur=test_state.do_something) - stateful_component = StatefulComponent.compile_from(button_component) - assert isinstance(stateful_component, StatefulComponent) - - # No event trigger? No StatefulComponent - assert not isinstance( - StatefulComponent.compile_from(rx.button("Click me")), StatefulComponent - ) - - -def test_stateful_banner(): - """Test that a stateful component is created correctly with events.""" - connection_modal_component = rx.connection_modal() - stateful_component = StatefulComponent.compile_from(connection_modal_component) - assert isinstance(stateful_component, StatefulComponent) - - TEST_VAR = LiteralVar.create("p")._replace( merge_var_data=VarData( hooks={"useTest": None}, @@ -1809,17 +1763,13 @@ class Inner(Component): tag = "Inner" library = "inner" - class Other(Component): - tag = "Other" - library = "other" - @rx.memo def wrapper(): return Inner.create() @rx.memo - def outer(c: Component): - return Other.create(c) + def outer(): + return wrapper() custom_comp = wrapper() @@ -1833,16 +1783,16 @@ def outer(c: Component): assert "inner" in imports_inner assert "outer" not in imports_inner - outer_comp = outer(c=wrapper()) + outer_comp = outer() - # Libraries are not imported directly, but are imported by the custom component. + # Nested custom components are only imported during compilation. assert "inner" not in outer_comp._get_all_imports() - assert "other" not in outer_comp._get_all_imports() # The imports are only resolved during compilation. _, imports_outer = compile_custom_component(outer_comp) assert "inner" not in imports_outer - assert "other" in imports_outer + assert "$/utils/components" in imports_outer + assert imports_outer["$/utils/components"] == [ImportVar(tag="Wrapper")] def test_custom_component_declare_event_handlers_in_fields(): diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index f202ecf05d8..99b547551da 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -101,9 +101,8 @@ def my_card( assert isinstance(definition, ExperimentalMemoComponentDefinition) assert any(str(prop) == "rest" for prop in definition.component.special_props) - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) + code = "\n".join(c for _, c in files) assert "export const MyCard = memo(({children, title:title" in code assert "...rest" in code assert "jsx(RadixThemesBox,{...rest}" in code @@ -126,9 +125,8 @@ def conditional_slot( "contents": "(showRxMemo ? firstRxMemo : secondRxMemo)" } - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) + code = "\n".join(c for _, c in files) assert "export const ConditionalSlot = memo(({show:showRxMemo" in code assert "(showRxMemo ? firstRxMemo : secondRxMemo)" in code @@ -151,9 +149,8 @@ def merge_styles( assert '["color"] : "red"' in str(merged) assert '["className"] : "primary"' in str(merged) - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) + code = "\n".join(c for _, c in files) assert ( "export const merge_styles = (({base, ...overrides}) => ({...base, ...overrides}));" in code @@ -185,9 +182,8 @@ def label_slot( assert '["children"]' in str(rendered) assert '["className"] : "slot"' in str(rendered) - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) + code = "\n".join(c for _, c in files) assert "export const label_slot = (({children, label, ...rest}) => label);" in code @@ -356,16 +352,79 @@ def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: def my_card(children: rx.Var[rx.Component], *, title: rx.Var[str]) -> rx.Component: return rx.box(rx.heading(title), children) - _, code, _ = compiler.compile_memo_components( + files, _ = compiler.compile_memo_components( dict.fromkeys(CUSTOM_COMPONENTS.values()), tuple(EXPERIMENTAL_MEMOS.values()), ) + code = "\n".join(c for _, c in files) assert "export const OldWrapper = memo(" in code assert "export const format_price =" in code assert "export const MyCard = memo(" in code +def test_compile_memo_components_extends_imports_without_remerging( + monkeypatch: pytest.MonkeyPatch, +): + """Memo import aggregation should not repeatedly reprocess prior imports.""" + + def noop() -> None: + pass + + memos = tuple( + ExperimentalMemoComponentDefinition( + fn=noop, + python_name=f"memo_{idx}", + params=(), + export_name=f"Memo{idx}", + component=rx.fragment(), + passthrough_hole_child=None, + ) + for idx in range(5) + ) + + def fake_compile_experimental_component_memo( + definition: ExperimentalMemoComponentDefinition, + ) -> tuple[dict[str, str], dict[str, list[ImportVar]]]: + return {"name": definition.export_name}, {} + + def fake_compile_single_memo_component( + component_render: dict[str, str], + component_imports: dict[str, list[ImportVar]], + ) -> tuple[str, dict[str, list[ImportVar]]]: + return ( + f"export const {component_render['name']} = null", + {"shared-lib": [ImportVar(tag=component_render["name"])]}, + ) + + real_merge_imports = compiler.utils.merge_imports + + def reject_growing_merge(*imports): + if len(imports) == 2 and imports[0]: + msg = "aggregate imports should be extended, not remerged" + raise AssertionError(msg) + return real_merge_imports(*imports) + + monkeypatch.setattr( + compiler.utils, + "compile_experimental_component_memo", + fake_compile_experimental_component_memo, + ) + monkeypatch.setattr( + compiler, + "_compile_single_memo_component", + fake_compile_single_memo_component, + ) + monkeypatch.setattr(compiler.utils, "merge_imports", reject_growing_merge) + + files, aggregate_imports = compiler.compile_memo_components((), memos) + + assert len(files) == len(memos) + 1 + assert [import_var.tag for import_var in aggregate_imports["shared-lib"]] == [ + f"Memo{idx}" for idx in range(5) + ] + + def test_experimental_component_memo_get_imports(): """Experimental component memos should resolve imports during compilation.""" @@ -415,6 +474,29 @@ def wrapper() -> rx.Component: assert definition.component.style == Style() +def test_component_returning_memo_is_transparent_for_child_validation(): + """Experimental memo wrappers should not break `_valid_parents` checks.""" + + class ValidParent(Component): + tag = "ValidParent" + library = "valid-parent" + + class RestrictedChild(Component): + tag = "RestrictedChild" + library = "restricted-child" + _valid_parents = ["ValidParent"] + + @rx._x.memo + def transparent(children: rx.Var[rx.Component]) -> rx.Component: + return children # type: ignore[return-value] + + wrapped_child = transparent(RestrictedChild.create()) + parent = ValidParent.create(wrapped_child) + + assert isinstance(wrapped_child, ExperimentalMemoComponent) + assert parent.children == [wrapped_child] + + def test_compile_memo_components_includes_experimental_custom_code(): """Experimental component memos should include custom code in compiled output.""" @@ -428,8 +510,7 @@ def add_custom_code(self) -> list[str]: def foo_component(label: rx.Var[str]) -> rx.Component: return FooComponent.create(label, rx.Var("foo")) - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + files, _ = compiler.compile_memo_components((), tuple(EXPERIMENTAL_MEMOS.values())) + code = "\n".join(c for _, c in files) assert "const foo = 'bar'" in code diff --git a/tests/units/plugins/test_tailwind.py b/tests/units/plugins/test_tailwind.py new file mode 100644 index 00000000000..3f40b960172 --- /dev/null +++ b/tests/units/plugins/test_tailwind.py @@ -0,0 +1,41 @@ +from pathlib import Path + +import pytest +from reflex_base.plugins import tailwind_v3, tailwind_v4 + + +@pytest.mark.parametrize("module", [tailwind_v3, tailwind_v4]) +def test_compile_root_style_omits_radix_when_disabled(module): + """Tailwind root styles should omit the Radix import when disabled.""" + _, code = module.compile_root_style(include_radix_themes=False) + + assert "@radix-ui/themes/styles.css" not in code + + +@pytest.mark.parametrize("module", [tailwind_v3, tailwind_v4]) +def test_add_tailwind_to_css_file_inserts_import_without_radix(module): + """Tailwind should still be added when the root stylesheet has no Radix import.""" + css = ( + "@layer __reflex_base;\n" + "@import url('./__reflex_style_reset.css');\n" + "@import url('./style.css');" + ) + + updated_css = module.add_tailwind_to_css_file( + css, + include_radix_themes=False, + ) + + assert updated_css.splitlines() == [ + "@layer __reflex_base;", + "@import url('./__reflex_style_reset.css');", + "@import url('./tailwind.css');", + "@import url('./style.css');", + ] + + +def test_v3_compile_root_style_keeps_expected_output_path(): + """Tailwind v3 should continue writing to the shared tailwind.css path.""" + output_path, _ = tailwind_v3.compile_root_style(include_radix_themes=False) + + assert output_path == str(Path("styles") / "tailwind.css") diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 077b7babb07..a772b3a3eae 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1319,72 +1319,6 @@ async def send(_message): assert bio.closed -@pytest.mark.asyncio -async def test_upload_file_cancels_buffered_handler_on_disconnect(token: str): - """Buffered uploads cancel the streaming handler on client disconnect. - - Args: - token: A token. - """ - request_mock = unittest.mock.Mock() - request_mock.headers = { - "reflex-client-token": token, - "reflex-event-handler": f"{FileUploadState.get_full_name()}.multi_handle_upload", - } - - bio = io.BytesIO(b"contents of image one") - file1 = UploadFile(filename="image1.jpg", file=bio) - form_data = FormData([("files", file1)]) - original_close = form_data.close - form_close = AsyncMock(side_effect=original_close) - form_data.close = form_close - - async def form(): # noqa: RUF029 - return form_data - - request_mock.form = form - - stream_started = asyncio.Event() - stream_closed = asyncio.Event() - - async def enqueue_stream_delta(_token, _event): - try: - stream_started.set() - yield {"state": {"ok": True}} - await asyncio.Event().wait() - finally: - stream_closed.set() - - app = Mock( - event_processor=Mock(enqueue_stream_delta=enqueue_stream_delta), - ) - - upload_fn = upload(app) - streaming_response = await upload_fn(request_mock) - - assert isinstance(streaming_response, StreamingResponse) - - async def receive(): - await stream_started.wait() - return {"type": "http.disconnect"} - - async def send(_message): # noqa: RUF029 - return None - - await asyncio.wait_for( - streaming_response( - {"type": "http", "asgi": {"spec_version": "2.4"}}, - receive, - send, - ), - timeout=1, - ) - - await asyncio.wait_for(stream_closed.wait(), timeout=1) - assert form_close.await_count == 1 - assert bio.closed - - @pytest.mark.asyncio async def test_upload_file_raises_client_disconnect_when_stream_send_fails( token: str, @@ -2174,6 +2108,184 @@ def test_app_wrap_compile_theme( assert expected.split(",") == function_app_definition.split(",") +def test_compile_without_radix_components_skips_radix_plugin( + compilable_app: tuple[App, Path], + mocker: MockerFixture, +): + """Pure HTML apps should not include Radix Themes assets or wrappers.""" + conf = rx.Config(app_name="testing") + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + mock_deprecate = mocker.patch("reflex_base.utils.console.deprecate") + + app.add_page(lambda: rx.el.div("Index"), route="/") + app.add_page(lambda: rx.el.div("404"), route=constants.Page404.SLUG) + app._compile() + + root_stylesheet = ( + web_dir + / constants.Dirs.STYLES + / f"{constants.PageNames.STYLESHEET_ROOT}{constants.Ext.CSS}" + ).read_text() + app_root = ( + web_dir / constants.Dirs.PAGES / constants.PageNames.APP_ROOT + ).read_text() + + assert "@radix-ui/themes/styles.css" not in root_stylesheet + assert "RadixThemesTheme" not in app_root + mock_deprecate.assert_not_called() + + +def test_compile_with_radix_component_auto_enables_radix_plugin( + compilable_app: tuple[App, Path], + mocker: MockerFixture, +): + """Using a Radix Themes component should enable the plugin with a warning.""" + conf = rx.Config(app_name="testing") + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + mock_deprecate = mocker.patch("reflex_base.utils.console.deprecate") + + app.add_page(lambda: rx.box("Index"), route="/") + app.add_page(lambda: rx.el.div("404"), route=constants.Page404.SLUG) + app._compile() + + root_stylesheet = ( + web_dir + / constants.Dirs.STYLES + / f"{constants.PageNames.STYLESHEET_ROOT}{constants.Ext.CSS}" + ).read_text() + app_root = ( + web_dir / constants.Dirs.PAGES / constants.PageNames.APP_ROOT + ).read_text() + + assert "@radix-ui/themes/styles.css" in root_stylesheet + assert 'RadixThemesTheme,{accentColor:"blue"' in app_root + mock_deprecate.assert_called_once() + assert ( + mock_deprecate.call_args.kwargs["feature_name"] + == "Implicit Radix Themes enablement" + ) + + +def test_compile_with_legacy_app_theme_warns_and_enables_radix_plugin( + compilable_app: tuple[App, Path], + mocker: MockerFixture, +): + """``App(theme=...)`` should continue to work with a deprecation warning.""" + conf = rx.Config(app_name="testing") + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + mock_deprecate = mocker.patch("reflex_base.utils.console.deprecate") + + app.theme = rx.theme(accent_color="plum") + app.add_page(lambda: rx.el.div("Index"), route="/") + app.add_page(lambda: rx.el.div("404"), route=constants.Page404.SLUG) + app._compile() + + root_stylesheet = ( + web_dir + / constants.Dirs.STYLES + / f"{constants.PageNames.STYLESHEET_ROOT}{constants.Ext.CSS}" + ).read_text() + app_root = ( + web_dir / constants.Dirs.PAGES / constants.PageNames.APP_ROOT + ).read_text() + + assert "@radix-ui/themes/styles.css" in root_stylesheet + assert 'RadixThemesTheme,{accentColor:"plum"' in app_root + mock_deprecate.assert_called_once() + assert mock_deprecate.call_args.kwargs["feature_name"] == "App(theme=...)" + + +def test_explicit_radix_plugin_wins_over_legacy_app_theme( + compilable_app: tuple[App, Path], + mocker: MockerFixture, +): + """Explicit RadixThemesPlugin config should win over deprecated App.theme.""" + conf = rx.Config( + app_name="testing", + plugins=[rx.plugins.RadixThemesPlugin(theme=rx.theme(accent_color="green"))], + ) + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + mock_deprecate = mocker.patch("reflex_base.utils.console.deprecate") + + app.theme = rx.theme(accent_color="plum") + app.add_page(lambda: rx.el.div("Index"), route="/") + app.add_page(lambda: rx.el.div("404"), route=constants.Page404.SLUG) + app._compile() + + app_root = ( + web_dir / constants.Dirs.PAGES / constants.PageNames.APP_ROOT + ).read_text() + + assert 'RadixThemesTheme,{accentColor:"green"' in app_root + assert 'RadixThemesTheme,{accentColor:"plum"' not in app_root + mock_deprecate.assert_called_once() + assert mock_deprecate.call_args.kwargs["feature_name"] == "App(theme=...)" + + +def test_compile_writes_app_wrap_memo_components( + compilable_app: tuple[App, Path], + mocker, +) -> None: + """App-wrap memo components are emitted to the shared components module.""" + conf = rx.Config(app_name="testing") + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + + app.add_page(rx.box("Index"), route="/") + app._compile() + + components_index = ( + web_dir + / constants.Dirs.UTILS + / f"{constants.PageNames.COMPONENTS}{constants.Ext.JSX}" + ).read_text() + + # Per-memo modules live under .web/utils/components/; the index re-exports + # each one so page-side ``$/utils/components`` resolves the same tags. + assert "DefaultOverlayComponents" in components_index + assert "MemoizedToastProvider" in components_index + assert 'from "./components/DefaultOverlayComponents"' in components_index + assert 'from "./components/MemoizedToastProvider"' in components_index + + memo_dir = web_dir / constants.Dirs.UTILS / constants.PageNames.COMPONENTS + assert (memo_dir / f"DefaultOverlayComponents{constants.Ext.JSX}").exists() + assert (memo_dir / f"MemoizedToastProvider{constants.Ext.JSX}").exists() + + +def test_compile_writes_upload_files_provider_app_wrap( + compilable_app: tuple[App, Path], + mocker, +) -> None: + """Upload pages emit the UploadFilesProvider app wrap into the app root.""" + conf = rx.Config(app_name="testing") + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + + app.add_page( + lambda: rx.upload.root( + rx.vstack( + rx.button("Select File"), + rx.text("Drag and drop files here or click to select files"), + ), + ), + route="/", + ) + app._compile() + + root_js = web_dir / constants.Dirs.PAGES / constants.PageNames.APP_ROOT + root_contents = root_js.read_text() + + assert "UploadFilesProvider" in root_contents + + @pytest.mark.parametrize( "react_strict_mode", [True, False], diff --git a/tests/units/test_environment.py b/tests/units/test_environment.py index 59e3c5ed438..8e7796b3ebe 100644 --- a/tests/units/test_environment.py +++ b/tests/units/test_environment.py @@ -12,7 +12,6 @@ from reflex_base.environment import ( EnvironmentVariables, EnvVar, - ExecutorType, ExistingPath, PerformanceMode, SequenceOptions, @@ -408,47 +407,6 @@ class TestEnv: assert env_var_instance.default == "default" -class TestExecutorType: - """Test the ExecutorType enum and related functionality.""" - - def test_executor_type_values(self): - """Test ExecutorType enum values.""" - assert ExecutorType.THREAD.value == "thread" - assert ExecutorType.PROCESS.value == "process" - assert ExecutorType.MAIN_THREAD.value == "main_thread" - - def test_get_executor_main_thread_mode(self): - """Test executor selection in main thread mode.""" - with ( - patch.object( - environment.REFLEX_COMPILE_EXECUTOR, - "get", - return_value=ExecutorType.MAIN_THREAD, - ), - patch.object( - environment.REFLEX_COMPILE_PROCESSES, "get", return_value=None - ), - patch.object(environment.REFLEX_COMPILE_THREADS, "get", return_value=None), - ): - executor = ExecutorType.get_executor_from_environment() - - # Test the main thread executor functionality - with executor: - future = executor.submit(lambda x: x * 2, 5) - assert future.result() == 10 - - def test_get_executor_returns_executor(self): - """Test that get_executor_from_environment returns an executor.""" - # Test with default values - should return some kind of executor - executor = ExecutorType.get_executor_from_environment() - assert executor is not None - - # Test that we can use it as a context manager - with executor: - future = executor.submit(lambda: "test") - assert future.result() == "test" - - class TestUtilityFunctions: """Test utility functions.""" diff --git a/tests/units/utils/test_streaming_response.py b/tests/units/utils/test_streaming_response.py index 122af46a9b7..9ee09a7801c 100644 --- a/tests/units/utils/test_streaming_response.py +++ b/tests/units/utils/test_streaming_response.py @@ -9,47 +9,11 @@ from starlette.requests import ClientDisconnect -@pytest.mark.asyncio -async def test_disconnect_cancels_stream_task_and_runs_finish(): - """A receive-side disconnect cancels the body stream and cleanup runs once.""" - body_closed = asyncio.Event() - body_started = asyncio.Event() - on_finish = AsyncMock() - - async def body(): - try: - body_started.set() - yield b"payload" - await asyncio.Event().wait() - finally: - body_closed.set() - - async def receive(): - await body_started.wait() - return {"type": "http.disconnect"} - - async def send(_message): - await asyncio.sleep(0) - - response = DisconnectAwareStreamingResponse( - body(), - media_type="application/x-ndjson", - on_finish=on_finish, - ) - - await asyncio.wait_for( - response({"type": "http", "asgi": {"spec_version": "2.4"}}, receive, send), - timeout=1, - ) - - await asyncio.wait_for(body_closed.wait(), timeout=1) - on_finish.assert_awaited_once() - - @pytest.mark.asyncio async def test_send_oserror_raises_client_disconnect_and_closes_body(): """A send-side disconnect still raises ClientDisconnect and closes the stream.""" body_closed = asyncio.Event() + disconnect_notified = asyncio.Event() on_finish = AsyncMock() async def body(): @@ -74,12 +38,14 @@ async def send(message): body(), media_type="application/x-ndjson", on_finish=on_finish, + on_disconnect=disconnect_notified.set, ) with pytest.raises(ClientDisconnect): await response({"type": "http", "asgi": {"spec_version": "2.4"}}, receive, send) await asyncio.wait_for(body_closed.wait(), timeout=1) + assert disconnect_notified.is_set() on_finish.assert_awaited_once()