diff --git a/packages/reflex-base/src/reflex_base/compiler/templates.py b/packages/reflex-base/src/reflex_base/compiler/templates.py index b4056227987..3c525e99f4e 100644 --- a/packages/reflex-base/src/reflex_base/compiler/templates.py +++ b/packages/reflex-base/src/reflex_base/compiler/templates.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from reflex.compiler.utils import _ImportDict - from reflex_base.components.component import Component, StatefulComponent + from reflex_base.components.component import Component def _sort_hooks( @@ -348,7 +348,12 @@ def context_template( export const initialState = {"{}" if not initial_state else json_dumps(initial_state)} export const defaultColorMode = {default_color_mode} -export const ColorModeContext = createContext(null); +export const ColorModeContext = createContext({{ + colorMode: defaultColorMode, + resolvedColorMode: defaultColorMode === "dark" ? "dark" : "light", + toggleColorMode: () => {{}}, + setColorMode: () => {{}}, +}}); export const UploadFilesContext = createContext(null); export const DispatchContext = createContext(null); export const StateContexts = {{{state_contexts_str}}}; @@ -417,7 +422,7 @@ def context_template( }}""" -def component_template(component: Component | StatefulComponent): +def component_template(component: Component): """Template to render a component tag. Args: @@ -618,24 +623,23 @@ def vite_config_template( }}));""" -def stateful_component_template( - tag_name: str, memo_trigger_hooks: list[str], component: Component, export: bool -): - """Template for stateful component. +def dynamic_component_template( + tag_name: str, component: Component, export: bool +) -> str: + """Template for a dynamic SSR component function declaration. Args: tag_name: The tag name for the component. - memo_trigger_hooks: The memo trigger hooks for the component. component: The component to render. export: Whether to export the component. Returns: - Rendered stateful component code as string. + Rendered dynamic component code as string. """ all_hooks = component._get_all_hooks() return f""" {"export " if export else ""}function {tag_name} () {{ - {_render_hooks(all_hooks, memo_trigger_hooks)} + {_render_hooks(all_hooks)} return ( {_RenderUtils.render(component.render())} ) @@ -643,15 +647,17 @@ def stateful_component_template( """ -def stateful_components_template(imports: list[_ImportDict], memoized_code: str) -> str: - """Template for stateful components. +def dynamic_components_module_template( + imports: list[_ImportDict], memoized_code: str +) -> str: + """Template for a dynamic-SSR components module. Args: imports: List of import statements. - memoized_code: Memoized code for stateful components. + memoized_code: Code for the module body. Returns: - Rendered stateful components code as string. + Rendered module code as string. """ imports_str = "\n".join([_RenderUtils.get_import(imp) for imp in imports]) return f"{imports_str}\n{memoized_code}" @@ -709,6 +715,83 @@ def memo_components_template( {components_code}""" +def memo_single_component_template( + imports: list[_ImportDict], + component: dict[str, Any], + dynamic_imports: Iterable[str], + custom_codes: Iterable[str], +) -> str: + """Template for a single memoized component in its own module. + + Args: + imports: List of import statements for this memo only. + component: The single component definition to render. + dynamic_imports: Dynamic import statements scoped to this memo. + custom_codes: Custom code snippets scoped to this memo. + + Returns: + The rendered standalone memo module code. + """ + imports_str = "\n".join([_RenderUtils.get_import(imp) for imp in imports]) + dynamic_imports_str = "\n".join(dynamic_imports) + custom_code_str = "\n".join(custom_codes) + + component_code = f""" +export const {component["name"]} = memo(({component["signature"]}) => {{ + {_render_hooks(component.get("hooks", {}))} + return( + {_RenderUtils.render(component["render"])} + ) +}}); +""" + + return f""" +{imports_str} + +{dynamic_imports_str} + +{custom_code_str} + +{component_code}""" + + +def memo_single_function_template( + imports: list[_ImportDict], + function: dict[str, Any], +) -> str: + """Template for a single function memo in its own module. + + Args: + imports: List of import statements for this memo only. + function: The single function memo definition. + + Returns: + The rendered standalone function memo module code. + """ + imports_str = "\n".join([_RenderUtils.get_import(imp) for imp in imports]) + return f""" +{imports_str} + +export const {function["name"]} = {function["function"]}; +""" + + +def memo_index_template(reexports: Iterable[tuple[str, str]]) -> str: + """Template for the memo index module that re-exports every memo file. + + Args: + reexports: Iterable of ``(export_name, relative_module_specifier)``. + + Returns: + The rendered index module code. + """ + lines = [ + f'export {{ {export_name} }} from "{specifier}";' + for export_name, specifier in reexports + ] + return "\n".join(lines) + "\n" + + def styles_template(stylesheets: list[str]) -> str: """Template for styles.css. diff --git a/packages/reflex-base/src/reflex_base/components/component.py b/packages/reflex-base/src/reflex_base/components/component.py index 8f47f447c7d..79b562ddc78 100644 --- a/packages/reflex-base/src/reflex_base/components/component.py +++ b/packages/reflex-base/src/reflex_base/components/component.py @@ -3,11 +3,11 @@ from __future__ import annotations import contextlib -import copy import dataclasses import enum import functools import inspect +import operator import typing from abc import ABC, ABCMeta, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence @@ -22,19 +22,10 @@ from reflex_base import constants from reflex_base.breakpoints import Breakpoints -from reflex_base.compiler.templates import stateful_component_template from reflex_base.components.dynamic import load_dynamic_serializer from reflex_base.components.field import BaseField, FieldBasedMeta from reflex_base.components.tags import Tag -from reflex_base.constants import ( - Dirs, - EventTriggers, - Hooks, - Imports, - MemoizationDisposition, - MemoizationMode, - PageNames, -) +from reflex_base.constants import Dirs, EventTriggers, Hooks, Imports, MemoizationMode from reflex_base.constants.compiler import SpecialAttributes from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER from reflex_base.event import ( @@ -321,6 +312,31 @@ def set(self, **kwargs): setattr(self, key, value) return self + def __copy__(self) -> BaseComponent: + """Return a shallow copy suitable for compile-time mutation. + + Bypasses ``copy.copy``'s generic ``__reduce_ex__`` dispatch. Nested + mutable containers (``children``, ``style``, ``event_triggers``) are + shared with the original until the caller explicitly rebinds them. + Render-path caches populated on the original are dropped so the clone + recomputes against its (potentially rebound) fields. + + Returns: + A new instance of the same class with ``__dict__`` shallow-copied. + """ + new = self.__class__.__new__(self.__class__) + new_dict = vars(new) + new_dict.update(vars(self)) + for attr in ( + "_cached_render_result", + "_vars_cache", + "_imports_cache", + "_hooks_internal_cache", + "_get_component_prop_property", + ): + new_dict.pop(attr, None) + return new + def __eq__(self, value: Any) -> bool: """Check if the component is equal to another value. @@ -502,14 +518,65 @@ def _hash_str(value: str) -> str: return md5(f'"{value}"'.encode(), usedforsecurity=False).hexdigest() -def _hash_sequence(value: Sequence) -> str: - return _hash_str(str([_deterministic_hash(v) for v in value])) +def _update_deterministic_hash(hasher: Any, value: object) -> None: + """Feed ``value`` into ``hasher`` using a self-delimiting, type-tagged encoding. + Each branch writes a distinct type tag plus length-prefixed payload, which + keeps the encoding injective without building intermediate strings — the + nested ``str([...])`` approach this replaces was the dominant cost of + ``_deterministic_hash`` (~4x speedup on synthetic, ~2x on real renders). -def _hash_dict(value: dict) -> str: - return _hash_sequence( - sorted([(k, _deterministic_hash(v)) for k, v in value.items()]) - ) + Args: + hasher: A ``hashlib`` hasher (must accept ``.update(bytes)``). + value: The value to fold into the hasher. + + Raises: + TypeError: If the value is not hashable. + """ + if value is None: + hasher.update(b"N") + elif isinstance(value, bool): + hasher.update(b"T" if value else b"F") + elif isinstance(value, (int, float, enum.Enum)): + hasher.update(b"n") + hasher.update(str(value).encode()) + elif isinstance(value, str): + encoded = value.encode() + hasher.update(b"s") + hasher.update(len(encoded).to_bytes(8, "little")) + hasher.update(encoded) + elif isinstance(value, dict): + items = sorted(value.items(), key=operator.itemgetter(0)) + hasher.update(b"d") + hasher.update(len(items).to_bytes(8, "little")) + for k, v in items: + _update_deterministic_hash(hasher, k) + _update_deterministic_hash(hasher, v) + elif isinstance(value, (tuple, list)): + hasher.update(b"l") + hasher.update(len(value).to_bytes(8, "little")) + for item in value: + _update_deterministic_hash(hasher, item) + elif isinstance(value, Var): + hasher.update(b"v") + _update_deterministic_hash(hasher, value._js_expr) + _update_deterministic_hash(hasher, value._get_all_var_data()) + elif dataclasses.is_dataclass(value): + fields = dataclasses.fields(value) + hasher.update(b"D") + hasher.update(len(fields).to_bytes(8, "little")) + for field in fields: + hasher.update(field.name.encode()) + _update_deterministic_hash(hasher, getattr(value, field.name)) + elif isinstance(value, BaseComponent): + hasher.update(b"C") + _update_deterministic_hash(hasher, value.render()) + else: + msg = ( + f"Cannot hash value `{value}` of type `{type(value).__name__}`. " + "Only BaseComponent, Var, VarData, dict, str, tuple, and enum.Enum are supported." + ) + raise TypeError(msg) def _deterministic_hash(value: object) -> str: @@ -524,36 +591,9 @@ def _deterministic_hash(value: object) -> str: Raises: TypeError: If the value is not hashable. """ - if value is None: - # Hash None as a special case. - return "None" - if isinstance(value, (int, float, enum.Enum)): - # Hash numbers and booleans directly. - return str(value) - if isinstance(value, str): - return _hash_str(value) - if isinstance(value, dict): - return _hash_dict(value) - if isinstance(value, (tuple, list)): - # Hash tuples by hashing each element. - return _hash_sequence(value) - if isinstance(value, Var): - return _hash_str( - str((value._js_expr, _deterministic_hash(value._get_all_var_data()))) - ) - if dataclasses.is_dataclass(value): - return _hash_dict({ - k.name: getattr(value, k.name) for k in dataclasses.fields(value) - }) - if isinstance(value, BaseComponent): - # If the value is a component, hash its rendered code. - return _hash_dict(value.render()) - - msg = ( - f"Cannot hash value `{value}` of type `{type(value).__name__}`. " - "Only BaseComponent, Var, VarData, dict, str, tuple, and enum.Enum are supported." - ) - raise TypeError(msg) + hasher = md5(usedforsecurity=False) + _update_deterministic_hash(hasher, value) + return hasher.hexdigest() @dataclasses.dataclass(kw_only=True, frozen=True, slots=True) @@ -1300,7 +1340,7 @@ def _add_style_recursive( # Recursively add style to the children. for child in self.children: - # Skip BaseComponent and StatefulComponent children. + # Skip non-Component children. if not isinstance(child, Component): continue child._add_style_recursive(style, theme) @@ -1327,6 +1367,10 @@ def render(self) -> dict: Returns: The dictionary for template of component. """ + try: + return self._cached_render_result + except AttributeError: + pass tag = self._render() rendered_dict = dict( tag.set( @@ -1334,6 +1378,7 @@ def render(self) -> dict: ) ) self._replace_prop_names(rendered_dict) + self._cached_render_result = rendered_dict return rendered_dict def _replace_prop_names(self, rendered_dict: dict) -> None: @@ -1457,11 +1502,14 @@ def _get_vars( Yields: Each var referenced by the component (props, styles, event handlers). """ + if not include_children and ignore_ids is None: + cached = self.__dict__.get("_vars_cache") + if cached is not None: + yield from cached + return + ignore_ids = ignore_ids or set() - vars: list[Var] | None = getattr(self, "__vars", None) - if vars is not None: - yield from vars - vars = self.__vars = [] + vars: list[Var] = [] # Get Vars associated with event trigger arguments. for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers): vars.extend(event_vars) @@ -1500,17 +1548,19 @@ def _get_vars( if var._get_all_var_data() is not None: vars.append(var) - # Get Vars associated with children. if include_children: + yield from vars for child in self.children: if not isinstance(child, Component) or id(child) in ignore_ids: continue ignore_ids.add(id(child)) - child_vars = child._get_vars( + yield from child._get_vars( include_children=include_children, ignore_ids=ignore_ids ) - vars.extend(child_vars) + return + # Freeze and cache the default-args result. + self._vars_cache = tuple(vars) yield from vars def _event_trigger_values_use_state(self) -> bool: @@ -1555,6 +1605,7 @@ def _iter_parent_classes_names(cls) -> Iterator[str]: yield clz.__name__ @classmethod + @functools.cache def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Component]]: """Iterate through parent classes that define a given method. @@ -1581,7 +1632,7 @@ def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Componen continue seen_methods.add(method_func) clzs.append(clz) - return clzs + return tuple(clzs) def _get_custom_code(self) -> str | None: """Get custom code for the component. @@ -1704,6 +1755,10 @@ def _get_imports(self) -> ParsedImportDict: Returns: The imports needed by the component. """ + cached = self.__dict__.get("_imports_cache") + if cached is not None: + return cached + imports_ = ( {self.library: [self.import_var]} if self.library is not None and self.tag is not None @@ -1731,7 +1786,7 @@ def _get_imports(self) -> ParsedImportDict: imports.parse_imports(item) for item in list_of_import_dict ]) - return imports.merge_parsed_imports( + result = imports.merge_parsed_imports( self._get_dependencies_imports(), self._get_hooks_imports(), imports_, @@ -1739,6 +1794,8 @@ def _get_imports(self) -> ParsedImportDict: *var_imports, *added_import_dicts, ) + self._imports_cache = result + return result def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict: """Get all the libraries and fields that are used by the component and its children. @@ -1840,15 +1897,21 @@ def _get_hooks_internal(self) -> dict[str, VarData | None]: Returns: The internally managed hooks. """ - return { + cached = self.__dict__.get("_hooks_internal_cache") + if cached is not None: + return cached + + result = { + **self._get_events_hooks(), **{ str(hook): VarData(position=Hooks.HookPosition.INTERNAL) for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()] if hook is not None }, **self._get_vars_hooks(), - **self._get_events_hooks(), } + self._hooks_internal_cache = result + return result def _get_added_hooks(self) -> dict[str, VarData | None]: """Get the hooks added via `add_hooks` method. @@ -2005,7 +2068,7 @@ def _get_all_app_wrap_components( # Add the app wrap components for the children. for child in self.children: child_id = id(child) - # Skip BaseComponent and StatefulComponent children. + # Skip non-Component children. if not isinstance(child, Component) or child_id in ignore_ids: continue ignore_ids.add(child_id) @@ -2382,440 +2445,6 @@ def _get_dynamic_imports(self) -> str: ) -class StatefulComponent(BaseComponent): - """A component that depends on state and is rendered outside of the page component. - - If a StatefulComponent is used in multiple pages, it will be rendered to a common file and - imported into each page that uses it. - - A stateful component has a tag name that includes a hash of the code that it renders - to. This tag name refers to the specific component with the specific props that it - was created with. - """ - - # Reference to the original component that was memoized into this component. - component: Component = field( - default_factory=Component, is_javascript_property=False - ) - - references: int = field( - doc="How many times this component is referenced in the app.", - default=0, - is_javascript_property=False, - ) - - rendered_as_shared: bool = field( - doc="Whether the component has already been rendered to a shared file.", - default=False, - is_javascript_property=False, - ) - - memo_trigger_hooks: list[str] = field( - default_factory=list, is_javascript_property=False - ) - - @classmethod - def create(cls, component: Component) -> StatefulComponent | None: - """Create a stateful component from a component. - - Args: - component: The component to memoize. - - Returns: - The stateful component or None if the component should not be memoized. - """ - from reflex_components_core.core.foreach import Foreach - - from reflex_base.registry import RegistrationContext - - if component._memoization_mode.disposition == MemoizationDisposition.NEVER: - # Never memoize this component. - return None - - if component.tag is None: - # Only memoize components with a tag. - return None - - # If _var_data is found in this component, it is a candidate for auto-memoization. - should_memoize = False - - # If the component requests to be memoized, then ignore other checks. - if component._memoization_mode.disposition == MemoizationDisposition.ALWAYS: - should_memoize = True - - if not should_memoize: - # Determine if any Vars have associated data. - for prop_var in component._get_vars(include_children=True): - if prop_var._get_all_var_data(): - should_memoize = True - break - - if not should_memoize: - # Check for special-cases in child components. - for child in component.children: - # Skip BaseComponent and StatefulComponent children. - if not isinstance(child, Component): - continue - # Always consider Foreach something that must be memoized by the parent. - if isinstance(child, Foreach): - should_memoize = True - break - child = cls._child_var(child) - if isinstance(child, Var) and child._get_all_var_data(): - should_memoize = True - break - - if should_memoize or component.event_triggers: - # Render the component to determine tag+hash based on component code. - tag_name = cls._get_tag_name(component) - if tag_name is None: - return None - - # Look up the tag in the cache - ctx = RegistrationContext.get() - stateful_component = ctx.tag_to_stateful_component.get(tag_name) - if stateful_component is None: - memo_trigger_hooks = cls._fix_event_triggers(component) - # Set the stateful component in the cache for the given tag. - stateful_component = ctx.tag_to_stateful_component.setdefault( - tag_name, - cls( - children=component.children, - component=component, - tag=tag_name, - memo_trigger_hooks=memo_trigger_hooks, - ), - ) - # Bump the reference count -- multiple pages referencing the same component - # will result in writing it to a common file. - stateful_component.references += 1 - return stateful_component - - # Return None to indicate this component should not be memoized. - return None - - @staticmethod - def _child_var(child: Component) -> Var | Component: - """Get the Var from a child component. - - This method is used for special cases when the StatefulComponent should actually - wrap the parent component of the child instead of recursing into the children - and memoizing them independently. - - Args: - child: The child component. - - Returns: - The Var from the child component or the child itself (for regular cases). - """ - from reflex_components_core.base.bare import Bare - from reflex_components_core.core.cond import Cond - from reflex_components_core.core.foreach import Foreach - from reflex_components_core.core.match import Match - - if isinstance(child, Bare): - return child.contents - if isinstance(child, Cond): - return child.cond - if isinstance(child, Foreach): - return child.iterable - if isinstance(child, Match): - return child.cond - return child - - @classmethod - def _get_tag_name(cls, component: Component) -> str | None: - """Get the tag based on rendering the given component. - - Args: - component: The component to render. - - Returns: - The tag for the stateful component. - """ - # Get the render dict for the component. - rendered_code = component.render() - if not rendered_code: - # Never memoize non-visual components. - return None - - # Compute the hash based on the rendered code. - code_hash = _hash_str(_deterministic_hash(rendered_code)) - - # Format the tag name including the hash. - return format.format_state_name( - f"{component.tag or 'Comp'}_{code_hash}" - ).capitalize() - - def _render_stateful_code( - self, - export: bool = False, - ) -> str: - if not self.tag: - return "" - # Render the code for this component and hooks. - return stateful_component_template( - tag_name=self.tag, - memo_trigger_hooks=self.memo_trigger_hooks, - component=self.component, - export=export, - ) - - @classmethod - def _fix_event_triggers( - cls, - component: Component, - ) -> list[str]: - """Render the code for a stateful component. - - Args: - component: The component to render. - - Returns: - The memoized event trigger hooks for the component. - """ - # Memoize event triggers useCallback to avoid unnecessary re-renders. - memo_event_triggers = tuple(cls._get_memoized_event_triggers(component).items()) - - # Trigger hooks stored separately to write after the normal hooks (see stateful_component.js.jinja2) - memo_trigger_hooks: list[str] = [] - - if memo_event_triggers: - # Copy the component to avoid mutating the original. - component = copy.copy(component) - - for event_trigger, ( - memo_trigger, - memo_trigger_hook, - ) in memo_event_triggers: - # Replace the event trigger with the memoized version. - memo_trigger_hooks.append(memo_trigger_hook) - component.event_triggers[event_trigger] = memo_trigger - - return memo_trigger_hooks - - @staticmethod - def _get_hook_deps(hook: str) -> list[str]: - """Extract var deps from a hook. - - Args: - hook: The hook line to extract deps from. - - Returns: - A list of var names created by the hook declaration. - """ - # Ensure that the hook is a var declaration. - var_decl = hook.partition("=")[0].strip() - if not any(var_decl.startswith(kw) for kw in ["const ", "let ", "var "]): - return [] - - # Extract the var name from the declaration. - _, _, var_name = var_decl.partition(" ") - var_name = var_name.strip() - - # Break up array and object destructuring if used. - if var_name.startswith(("[", "{")): - return [ - v.strip().replace("...", "") for v in var_name.strip("[]{}").split(",") - ] - return [var_name] - - @staticmethod - def _get_deps_from_event_trigger( - event: EventChain | EventSpec | Var, - ) -> dict[str, None]: - """Get the dependencies accessed by event triggers. - - Args: - event: The event trigger to extract deps from. - - Returns: - The dependencies accessed by the event triggers. - """ - events: list = [event] - deps = {} - - if isinstance(event, EventChain): - events.extend(event.events) - - for ev in events: - if isinstance(ev, EventSpec): - for arg in ev.args: - for a in arg: - var_datas = VarData.merge(a._get_all_var_data()) - if var_datas and var_datas.deps is not None: - deps |= {str(dep): None for dep in var_datas.deps} - return deps - - @classmethod - def _get_memoized_event_triggers( - cls, - component: Component, - ) -> dict[str, tuple[Var, str]]: - """Memoize event handler functions with useCallback to avoid unnecessary re-renders. - - Args: - component: The component with events to memoize. - - Returns: - A dict of event trigger name to a tuple of the memoized event trigger Var and - the hook code that memoizes the event handler. - """ - trigger_memo = {} - for event_trigger, event_args in component._get_vars_from_event_triggers( - component.event_triggers - ): - if event_trigger in { - EventTriggers.ON_MOUNT, - EventTriggers.ON_UNMOUNT, - EventTriggers.ON_SUBMIT, - }: - # Do not memoize lifecycle or submit events. - continue - - # Get the actual EventSpec and render it. - event = component.event_triggers[event_trigger] - rendered_chain = str(LiteralVar.create(event)) - - # Hash the rendered EventChain to get a deterministic function name. - chain_hash = md5(str(rendered_chain).encode("utf-8")).hexdigest() - memo_name = f"{event_trigger}_{chain_hash}" - - # Calculate Var dependencies accessed by the handler for useCallback dep array. - var_deps = ["addEvents", "ReflexEvent"] - - # Get deps from event trigger var data. - var_deps.extend(cls._get_deps_from_event_trigger(event)) - - # Get deps from hooks. - for arg in event_args: - var_data = arg._get_all_var_data() - if var_data is None: - continue - for hook in var_data.hooks: - var_deps.extend(cls._get_hook_deps(hook)) - memo_var_data = VarData.merge( - *[var._get_all_var_data() for var in event_args], - VarData( - imports={"react": [ImportVar(tag="useCallback")]}, - ), - ) - - # Store the memoized function name and hook code for this event trigger. - trigger_memo[event_trigger] = ( - Var(_js_expr=memo_name)._replace( - _var_type=EventChain, merge_var_data=memo_var_data - ), - f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])", - ) - return trigger_memo - - def _get_all_hooks_internal(self) -> dict[str, VarData | None]: - """Get the reflex internal hooks for the component and its children. - - Returns: - The code that should appear just before user-defined hooks. - """ - return {} - - def _get_all_hooks(self) -> dict[str, VarData | None]: - """Get the React hooks for this component. - - Returns: - The code that should appear just before returning the rendered component. - """ - return {} - - def _get_all_imports(self) -> ParsedImportDict: - """Get all the libraries and fields that are used by the component. - - Returns: - The import dict with the required imports. - """ - if self.rendered_as_shared: - return { - f"$/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [ - ImportVar(tag=self.tag) - ] - } - return self.component._get_all_imports() - - def _get_all_dynamic_imports(self) -> set[str]: - """Get dynamic imports for the component. - - Returns: - The dynamic imports. - """ - if self.rendered_as_shared: - return set() - return self.component._get_all_dynamic_imports() - - def _get_all_custom_code(self, export: bool = False) -> dict[str, None]: - """Get custom code for the component. - - Args: - export: Whether to export the component. - - Returns: - The custom code. - """ - if self.rendered_as_shared: - return {} - return self.component._get_all_custom_code() | ({ - self._render_stateful_code(export=export): None - }) - - def _get_all_refs(self) -> dict[str, None]: - """Get the refs for the children of the component. - - Returns: - The refs for the children. - """ - if self.rendered_as_shared: - return {} - return self.component._get_all_refs() - - def render(self) -> dict: - """Define how to render the component in React. - - Returns: - The tag to render. - """ - return dict(Tag(name=self.tag or "")) - - def __str__(self) -> str: - """Represent the component in React. - - Returns: - The code to render the component. - """ - from reflex.compiler.compiler import _compile_component - - return _compile_component(self) - - @classmethod - def compile_from(cls, component: BaseComponent) -> BaseComponent: - """Walk through the component tree and memoize all stateful components. - - Args: - component: The component to memoize. - - Returns: - The memoized component tree. - """ - if isinstance(component, Component): - if component._memoization_mode.recursive: - # Recursively memoize stateful children (default). - component.children = [ - cls.compile_from(child) for child in component.children - ] - # Memoize this component if it depends on state. - stateful_component = cls.create(component) - if stateful_component is not None: - return stateful_component - return component - - class MemoizationLeaf(Component): """A component that does not separately memoize its children. @@ -2823,30 +2452,15 @@ class MemoizationLeaf(Component): components within it, should be a memoization leaf so the compiler does not replace the provided child tags with memoized tags. - During creation, a memoization leaf will mark itself as wanting to be - memoized if any of its children return any hooks. + Whether the leaf is wrapped in a memo definition is decided by the + compiler's snapshot-boundary subtree scan, not by a class-local + disposition override — so leaves and components that explicitly set + ``_memoization_mode = MemoizationMode(recursive=False)`` are handled + identically. """ _memoization_mode = MemoizationMode(recursive=False) - @classmethod - def create(cls, *children, **props) -> Component: - """Create a new memoization leaf component. - - Args: - *children: The children of the component. - **props: The props of the component. - - Returns: - The memoization leaf - """ - comp = super().create(*children, **props) - if comp._get_all_hooks(): - comp._memoization_mode = dataclasses.replace( - comp._memoization_mode, disposition=MemoizationDisposition.ALWAYS - ) - return comp - load_dynamic_serializer() diff --git a/packages/reflex-base/src/reflex_base/components/dynamic.py b/packages/reflex-base/src/reflex_base/components/dynamic.py index 6c2100a40e8..a668bd341fc 100644 --- a/packages/reflex-base/src/reflex_base/components/dynamic.py +++ b/packages/reflex-base/src/reflex_base/components/dynamic.py @@ -26,14 +26,20 @@ def get_cdn_url(lib: str) -> str: return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm" -bundled_libraries = [ +DEFAULT_BUNDLED_LIBRARIES = [ "react", - "@radix-ui/themes", "@emotion/react", f"$/{constants.Dirs.UTILS}/context", f"$/{constants.Dirs.UTILS}/state", f"$/{constants.Dirs.UTILS}/components", ] +bundled_libraries = list(DEFAULT_BUNDLED_LIBRARIES) + + +def reset_bundled_libraries() -> None: + """Reset the bundled library registry to its default values.""" + bundled_libraries.clear() + bundled_libraries.extend(DEFAULT_BUNDLED_LIBRARIES) def bundle_library(component: Union["Component", str]): @@ -46,7 +52,7 @@ def bundle_library(component: Union["Component", str]): DynamicComponentMissingLibraryError: Raised when a dynamic component is missing a library. """ if isinstance(component, str): - bundled_libraries.append(component) + bundled_libraries.append(format_library_name(component)) return if component.library is None: msg = "Component must have a library to bundle." @@ -85,9 +91,8 @@ def make_component(component: Component) -> str: rendered_components.update(component._get_all_custom_code()) rendered_components[ - templates.stateful_component_template( + templates.dynamic_component_template( tag_name="MySSRComponent", - memo_trigger_hooks=[], component=component, export=True, ) @@ -110,7 +115,7 @@ def make_component(component: Component) -> str: else: imports[lib] = names - module_code_lines = templates.stateful_components_template( + module_code_lines = templates.dynamic_components_module_template( imports=utils.compile_imports(imports), memoized_code="\n".join(rendered_components), ).splitlines() @@ -191,6 +196,7 @@ def evaluate_component(js_string: Var[str]) -> Var[Component]: imports.ImportVar(tag="evalReactComponent"), ], "react": [ + imports.ImportVar(tag="createElement"), imports.ImportVar(tag="useState"), imports.ImportVar(tag="useEffect"), ], @@ -202,7 +208,7 @@ def evaluate_component(js_string: Var[str]) -> Var[Component]: f"evalReactComponent({js_string!s})" ".then((component) => {" "if (isMounted) {" - f"set_{unique_var_name}(component);" + f"set_{unique_var_name}(() => createElement(component));" "}" "});" "return () => {" diff --git a/packages/reflex-base/src/reflex_base/components/memoize_helpers.py b/packages/reflex-base/src/reflex_base/components/memoize_helpers.py new file mode 100644 index 00000000000..9b74e881753 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/components/memoize_helpers.py @@ -0,0 +1,295 @@ +"""Memoization helpers for auto-memoized and pseudo-stateful components. + +These helpers wrap a component's non-lifecycle event triggers in ``useCallback`` +so that React can skip re-renders of subtrees whose event handlers have stable +identities. They are used by both the compiler auto-memoization plugin (see +``reflex.compiler.plugins.memoize``) and by component-creation-time consumers +in ``reflex-components-core`` (e.g. ``WindowEventListener``, ``upload``). + +Auto-memoized components compile using one of two render strategies: + +- Passthrough memo bodies render the root component with a ``{children}`` hole. + The page still renders the descendants, which keeps root-level introspection + such as ``Form._get_form_refs`` working against the authored child tree. +- Snapshot memo bodies render the captured subtree in the memo module. This is + required for non-recursive memoization leaves and structural forms + (``Foreach``) whose stateful render logic belongs inside the memo component + rather than the containing page. +""" + +from __future__ import annotations + +import enum +from hashlib import md5 +from typing import TYPE_CHECKING + +from reflex_base.components.component import BaseComponent, Component +from reflex_base.constants import EventTriggers +from reflex_base.event import EventChain, EventSpec +from reflex_base.utils.imports import ImportVar +from reflex_base.vars import VarData +from reflex_base.vars.base import LiteralVar, Var +from reflex_base.vars.sequence import ArrayVar + +if TYPE_CHECKING: + from reflex_base.plugins.compiler import PageContext + + +class MemoizationStrategy(enum.Enum): + """How an auto-memo wrapper should render a component if it is memoized.""" + + PASSTHROUGH = "passthrough" + SNAPSHOT = "snapshot" + + +def _get_hook_deps(hook: str) -> list[str]: + """Extract Var deps from a hook declaration line. + + Args: + hook: The hook line (e.g. ``"const foo = useState(...)"``). + + Returns: + The names of variables created by the declaration. + """ + var_decl = hook.partition("=")[0].strip() + if not any(var_decl.startswith(kw) for kw in ["const ", "let ", "var "]): + return [] + _, _, var_name = var_decl.partition(" ") + var_name = var_name.strip() + if var_name.startswith(("[", "{")): + return [v.strip().replace("...", "") for v in var_name.strip("[]{}").split(",")] + return [var_name] + + +def _get_deps_from_event_trigger( + event: EventChain | EventSpec | Var, +) -> dict[str, None]: + """Get the dependencies accessed by an event trigger value. + + Args: + event: The event trigger value. + + Returns: + Dependency names, insertion-ordered. + """ + events: list = [event] + deps: dict[str, None] = {} + + if isinstance(event, EventChain): + events.extend(event.events) + + for ev in events: + if isinstance(ev, EventSpec): + for arg in ev.args: + for a in arg: + var_datas = VarData.merge(a._get_all_var_data()) + if var_datas and var_datas.deps is not None: + deps |= {str(dep): None for dep in var_datas.deps} + return deps + + +def get_memoized_event_triggers( + component: Component, +) -> dict[str, Var]: + """Generate ``useCallback`` wrappers for the component's event triggers. + + Args: + component: The component whose event triggers should be memoized. + + Returns: + A dict mapping event trigger name to memoized_triger. + """ + trigger_memo: dict[str, Var] = {} + for event_trigger, event_args in component._get_vars_from_event_triggers( + component.event_triggers + ): + if event_trigger in { + EventTriggers.ON_MOUNT, + EventTriggers.ON_UNMOUNT, + EventTriggers.ON_SUBMIT, + }: + # Do not memoize lifecycle or submit events. + continue + + event = component.event_triggers[event_trigger] + rendered_chain = LiteralVar.create(event) + + chain_hash = md5( + str(rendered_chain).encode("utf-8"), usedforsecurity=False + ).hexdigest() + memo_name = f"{event_trigger}_{chain_hash}" + + var_deps = ["addEvents", "ReflexEvent"] + var_deps.extend(_get_deps_from_event_trigger(event)) + + event_var_data = [] + for arg in event_args: + var_data = arg._get_all_var_data() + if var_data is None: + continue + event_var_data.append(var_data) + for hook in var_data.hooks: + var_deps.extend(_get_hook_deps(hook)) + + memo_var_data = VarData.merge( + *event_var_data, + rendered_chain._get_all_var_data(), + VarData( + hooks=[ + f"const {memo_name} = useCallback({rendered_chain!s}, [{', '.join(var_deps)}])" + ], + imports={"react": [ImportVar(tag="useCallback")]}, + ), + ) + + trigger_memo[event_trigger] = Var( + _js_expr=memo_name, _var_type=EventChain, _var_data=memo_var_data + ) + return trigger_memo + + +def fix_event_triggers_for_memo( + component: Component, page_context: PageContext +) -> Component: + """Return a component whose event triggers reference memoized ``useCallback``s. + + Replaces each (non-lifecycle) event-trigger value with a ``Var`` naming a + memoized ``useCallback`` wrapper. The original is never mutated — a + page-local clone is taken via ``page_context.own`` on first write. + + Args: + component: The component whose event triggers to memoize. + page_context: The active page context, used to obtain a page-local + clone before rewriting ``event_triggers``. + + Returns: + Either ``component`` (when nothing needed rewriting) or a page-local + clone with the rewritten ``event_triggers``. + """ + memo_event_triggers = tuple(get_memoized_event_triggers(component).items()) + if not memo_event_triggers: + return component + owned = page_context.own(component) + owned.event_triggers = { + **component.event_triggers, + **dict(memo_event_triggers), + } + return owned + + +def is_snapshot_boundary(component: Component) -> bool: + """Whether ``component`` owns its subtree for memoization purposes. + + Snapshot boundaries (``MemoizationLeaf``-style components with + ``_memoization_mode.recursive=False``) encapsulate internal machinery as + their own structural children. The auto-memoize compiler pass must wrap + them whole and not walk or independently memoize that subtree. + + The check is the behavioral flag, not ``isinstance(MemoizationLeaf)``, so + components that opt into non-recursive memoization without subclassing + ``MemoizationLeaf`` are handled identically. + + Args: + component: The component to classify. + + Returns: + ``True`` iff descendants of ``component`` must not be independently + memoized and the memo wrapper must carry the full subtree snapshot. + """ + return not component._memoization_mode.recursive + + +def _is_structural_memoization_child(component: Component) -> bool: + """Check whether ``component`` is a structural child for memoization. + + Args: + component: The child component to inspect. + + Returns: + True when the component's render body must stay inside the generated + memo body rather than flowing through as a normal ``children`` payload. + """ + from reflex_components_core.core.foreach import Foreach + + return bool(isinstance(component, Foreach)) + + +def _has_memoization_snapshot_child(component: Component) -> bool: + """Whether ``component`` has a structural child that needs a memo snapshot. + + Component-valued ``Foreach`` is a structural form, not an ordinary user + child. It must render its body in the memo module instead of flowing + through as a normal ``children`` payload. + + Args: + component: The component whose direct children should be inspected. + + Returns: + True when a direct child requires the parent memo wrapper to render a + captured snapshot. + """ + return any( + isinstance(child, Component) and _is_structural_memoization_child(child) + for child in component.children + ) + + +def passthrough_children_var( + children: list[BaseComponent], +) -> ArrayVar[list[BaseComponent]] | None: + """Return the placeholder ``children`` array Var if ``children`` is a memo hole. + + Auto-memo passthrough wrappers replace the wrapped component's children + with a single ``Bare(Var[Component](_js_expr="children"))`` placeholder + when compiling the memo body. Render-time consumers (notably ``Cond`` and + ``Match``) detect this and rewrite branches to index into the placeholder + array instead of capturing the original branch JSX in the memo body. The + returned Var is retyped to ``list[BaseComponent]`` so callers can index + it directly. + + Args: + children: The component's children list. + + Returns: + The placeholder Var (retyped as a Component list) for indexed access, + else ``None``. + """ + from reflex_components_core.base.bare import Bare + + if ( + len(children) == 1 + and isinstance(children[0], Bare) + and isinstance(children[0].contents, Var) + and children[0].contents._js_expr == "children" + ): + return children[0].contents.to(list[BaseComponent]) + return None + + +def get_memoization_strategy(component: Component) -> MemoizationStrategy: + """Get the render strategy for ``component`` if auto-memoization wraps it. + + Args: + component: The component being considered by auto-memoization. + + Returns: + The strategy to use when generating a memo wrapper. + """ + if ( + is_snapshot_boundary(component) + or _is_structural_memoization_child(component) + or _has_memoization_snapshot_child(component) + ): + return MemoizationStrategy.SNAPSHOT + + return MemoizationStrategy.PASSTHROUGH + + +__all__ = [ + "MemoizationStrategy", + "fix_event_triggers_for_memo", + "get_memoization_strategy", + "get_memoized_event_triggers", + "is_snapshot_boundary", + "passthrough_children_var", +] diff --git a/packages/reflex-base/src/reflex_base/environment.py b/packages/reflex-base/src/reflex_base/environment.py index 31ebe795998..f2d4ce44cb9 100644 --- a/packages/reflex-base/src/reflex_base/environment.py +++ b/packages/reflex-base/src/reflex_base/environment.py @@ -2,14 +2,11 @@ from __future__ import annotations -import concurrent.futures import dataclasses import enum import importlib -import multiprocessing import os -import platform -from collections.abc import Callable, Sequence +from collections.abc import Sequence from functools import lru_cache from pathlib import Path from typing import ( @@ -529,97 +526,6 @@ class PerformanceMode(enum.Enum): OFF = "off" -class ExecutorType(enum.Enum): - """Executor for compiling the frontend.""" - - THREAD = "thread" - PROCESS = "process" - MAIN_THREAD = "main_thread" - - @classmethod - def get_executor_from_environment(cls): - """Get the executor based on the environment variables. - - Returns: - The executor. - """ - from reflex_base.utils import console - - executor_type = environment.REFLEX_COMPILE_EXECUTOR.get() - - reflex_compile_processes = environment.REFLEX_COMPILE_PROCESSES.get() - reflex_compile_threads = environment.REFLEX_COMPILE_THREADS.get() - # By default, use the main thread. Unless the user has specified a different executor. - # Using a process pool is much faster, but not supported on all platforms. It's gated behind a flag. - if executor_type is None: - if ( - platform.system() not in ("Linux", "Darwin") - and reflex_compile_processes is not None - ): - console.warn("Multiprocessing is only supported on Linux and MacOS.") - - if ( - platform.system() in ("Linux", "Darwin") - and reflex_compile_processes is not None - ): - if reflex_compile_processes == 0: - console.warn( - "Number of processes must be greater than 0. If you want to use the default number of processes, set REFLEX_COMPILE_EXECUTOR to 'process'. Defaulting to None." - ) - reflex_compile_processes = None - elif reflex_compile_processes < 0: - console.warn( - "Number of processes must be greater than 0. Defaulting to None." - ) - reflex_compile_processes = None - executor_type = ExecutorType.PROCESS - elif reflex_compile_threads is not None: - if reflex_compile_threads == 0: - console.warn( - "Number of threads must be greater than 0. If you want to use the default number of threads, set REFLEX_COMPILE_EXECUTOR to 'thread'. Defaulting to None." - ) - reflex_compile_threads = None - elif reflex_compile_threads < 0: - console.warn( - "Number of threads must be greater than 0. Defaulting to None." - ) - reflex_compile_threads = None - executor_type = ExecutorType.THREAD - else: - executor_type = ExecutorType.MAIN_THREAD - - match executor_type: - case ExecutorType.PROCESS: - executor = concurrent.futures.ProcessPoolExecutor( - max_workers=reflex_compile_processes, - mp_context=multiprocessing.get_context("fork"), - ) - case ExecutorType.THREAD: - executor = concurrent.futures.ThreadPoolExecutor( - max_workers=reflex_compile_threads - ) - case ExecutorType.MAIN_THREAD: - FUTURE_RESULT_TYPE = TypeVar("FUTURE_RESULT_TYPE") - - class MainThreadExecutor: - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - def submit( - self, fn: Callable[..., FUTURE_RESULT_TYPE], *args, **kwargs - ) -> concurrent.futures.Future[FUTURE_RESULT_TYPE]: - future_job = concurrent.futures.Future() - future_job.set_result(fn(*args, **kwargs)) - return future_job - - executor = MainThreadExecutor() - - return executor - - class EnvironmentVariables: """Environment variables class to instantiate environment variables.""" @@ -660,14 +566,6 @@ class EnvironmentVariables: Path(constants.Dirs.UPLOADED_FILES) ) - REFLEX_COMPILE_EXECUTOR: EnvVar[ExecutorType | None] = env_var(None) - - # Whether to use separate processes to compile the frontend and how many. If not set, defaults to thread executor. - REFLEX_COMPILE_PROCESSES: EnvVar[int | None] = env_var(None) - - # Whether to use separate threads to compile the frontend and how many. Defaults to `min(32, os.cpu_count() + 4)`. - REFLEX_COMPILE_THREADS: EnvVar[int | None] = env_var(None) - # The directory to store reflex dependencies. REFLEX_DIR: EnvVar[Path] = env_var(constants.Reflex.DIR) diff --git a/packages/reflex-base/src/reflex_base/plugins/__init__.py b/packages/reflex-base/src/reflex_base/plugins/__init__.py index b4320489b08..f3ef5aa971c 100644 --- a/packages/reflex-base/src/reflex_base/plugins/__init__.py +++ b/packages/reflex-base/src/reflex_base/plugins/__init__.py @@ -3,12 +3,26 @@ from . import sitemap, tailwind_v3, tailwind_v4 from ._screenshot import ScreenshotPlugin as _ScreenshotPlugin from .base import CommonContext, Plugin, PreCompileContext +from .compiler import ( + BaseContext, + CompileContext, + CompilerHooks, + ComponentAndChildren, + PageContext, + PageDefinition, +) from .sitemap import SitemapPlugin from .tailwind_v3 import TailwindV3Plugin from .tailwind_v4 import TailwindV4Plugin __all__ = [ + "BaseContext", "CommonContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "PageContext", + "PageDefinition", "Plugin", "PreCompileContext", "SitemapPlugin", diff --git a/packages/reflex-base/src/reflex_base/plugins/base.py b/packages/reflex-base/src/reflex_base/plugins/base.py index 52dfa8d7805..082258ddb9c 100644 --- a/packages/reflex-base/src/reflex_base/plugins/base.py +++ b/packages/reflex-base/src/reflex_base/plugins/base.py @@ -1,13 +1,25 @@ """Base class for all plugins.""" from collections.abc import Callable, Sequence +from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, ParamSpec, Protocol, TypedDict +from typing import TYPE_CHECKING, Any, ClassVar, ParamSpec, Protocol, TypedDict from typing_extensions import Unpack + +class HookOrder(str, Enum): + """Dispatch bucket for a compiler ``enter_component`` / ``leave_component`` hook.""" + + PRE = "pre" + NORMAL = "normal" + POST = "post" + + if TYPE_CHECKING: from reflex.app import App, UnevaluatedPage + from reflex_base.components.component import BaseComponent + from reflex_base.plugins.compiler import ComponentAndChildren, PageContext class CommonContext(TypedDict): @@ -41,6 +53,7 @@ class PreCompileContext(CommonContext): add_save_task: AddTaskProtocol add_modify_task: Callable[[str, Callable[[str], str]], None] + radix_themes_plugin: Any unevaluated_pages: Sequence["UnevaluatedPage"] @@ -53,6 +66,13 @@ class PostCompileContext(CommonContext): class Plugin: """Base class for all plugins.""" + # Dispatch position for ``enter_component`` and ``leave_component`` hooks. + # Plugins run in ``PRE`` → ``NORMAL`` → ``POST`` order. Within a bucket, + # enter hooks fire in plugin-chain order while leave hooks fire in + # reverse plugin-chain order (mirroring an enter/leave stack). + _compiler_enter_component_order: ClassVar[HookOrder] = HookOrder.NORMAL + _compiler_leave_component_order: ClassVar[HookOrder] = HookOrder.NORMAL + def get_frontend_development_dependencies( self, **context: Unpack[CommonContext] ) -> list[str] | set[str] | tuple[str, ...]: @@ -117,6 +137,78 @@ def post_compile(self, **context: Unpack[PostCompileContext]) -> None: context: The context for the plugin. """ + def eval_page( + self, + page_fn: Any, + /, + **kwargs: Any, + ) -> "PageContext | None": + """Evaluate a page-like object into a page context. + + Args: + page_fn: The page-like object to evaluate. + kwargs: Additional compiler-specific context. + + Returns: + A page context when the plugin can evaluate the page, otherwise ``None``. + """ + return None + + def compile_page( + self, + page_ctx: "PageContext", + /, + **kwargs: Any, + ) -> None: + """Finalize a page context after its component tree has been traversed.""" + return + + def enter_component( + self, + comp: "BaseComponent", + /, + *, + page_context: "PageContext", + compile_context: Any, + in_prop_tree: bool = False, + ) -> "BaseComponent | ComponentAndChildren | None": + """Inspect or transform a component before visiting its descendants. + + Args: + comp: The component being compiled. + page_context: The active page compilation state. + compile_context: The active compile-run state. + in_prop_tree: Whether the component is being visited through a prop subtree. + + Returns: + An optional replacement component and/or structural children. + """ + return None + + def leave_component( + self, + comp: "BaseComponent", + children: tuple["BaseComponent", ...], + /, + *, + page_context: "PageContext", + compile_context: Any, + in_prop_tree: bool = False, + ) -> "BaseComponent | ComponentAndChildren | None": + """Inspect or transform a component after visiting its descendants. + + Args: + comp: The component being compiled. + children: The compiled structural children for the component. + page_context: The active page compilation state. + compile_context: The active compile-run state. + in_prop_tree: Whether the component is being visited through a prop subtree. + + Returns: + An optional replacement component and/or structural children. + """ + return None + def __repr__(self): """Return a string representation of the plugin. diff --git a/packages/reflex-base/src/reflex_base/plugins/compiler.py b/packages/reflex-base/src/reflex_base/plugins/compiler.py new file mode 100644 index 00000000000..ecb55a03d92 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/plugins/compiler.py @@ -0,0 +1,869 @@ +"""Compiler plugin infrastructure: protocols, contexts, and dispatch.""" + +from __future__ import annotations + +import copy +import dataclasses +import inspect +from collections.abc import Callable, Sequence +from contextvars import ContextVar, Token +from types import TracebackType +from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, cast + +from typing_extensions import Self + +from reflex_base.components.component import BaseComponent, Component +from reflex_base.utils.imports import ParsedImportDict, collapse_imports, merge_imports +from reflex_base.vars import VarData + +from .base import HookOrder, Plugin + +if TYPE_CHECKING: + from reflex.app import App, ComponentCallable + + PageComponent: TypeAlias = Component | ComponentCallable +else: + PageComponent: TypeAlias = ( + Component + | Callable[ + [], + Component | tuple[Component, ...] | str, + ] + ) + + +_BaseComponentT = TypeVar("_BaseComponentT", bound=BaseComponent) + + +class PageDefinition(Protocol): + """Protocol for page-like objects compiled by :class:`CompileContext`.""" + + @property + def route(self) -> str: + """Return the route for this page definition.""" + ... + + @property + def component(self) -> PageComponent: + """Return the component or callable for this page definition.""" + ... + + +ComponentAndChildren: TypeAlias = tuple[BaseComponent, tuple[BaseComponent, ...]] +ComponentReplacement: TypeAlias = BaseComponent | ComponentAndChildren | None +CompiledEnterHook: TypeAlias = Callable[ + [BaseComponent, bool], + ComponentReplacement, +] +CompiledLeaveHook: TypeAlias = Callable[ + [BaseComponent, tuple[BaseComponent, ...], bool], + ComponentReplacement, +] +EnterHookBinder: TypeAlias = Callable[ + ["PageContext", "CompileContext"], + CompiledEnterHook, +] +LeaveHookBinder: TypeAlias = Callable[ + ["PageContext", "CompileContext"], + CompiledLeaveHook, +] + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class CompilerHooks: + """Dispatch compiler hooks across an ordered plugin chain.""" + + plugins: tuple[Plugin, ...] = () + _eval_page_hooks: tuple[Callable[..., Any], ...] = dataclasses.field( + init=False, + repr=False, + ) + _compile_page_hooks: tuple[Callable[..., Any], ...] = dataclasses.field( + init=False, + repr=False, + ) + _enter_component_hook_binders: tuple[EnterHookBinder, ...] = dataclasses.field( + init=False, + repr=False, + ) + _leave_component_hook_binders: tuple[LeaveHookBinder, ...] = dataclasses.field( + init=False, + repr=False, + ) + _component_hooks_can_replace: bool = dataclasses.field( + init=False, + repr=False, + ) + + def __post_init__(self) -> None: + """Resolve the active compiler hook callables once.""" + object.__setattr__(self, "_eval_page_hooks", self._resolve_hooks("eval_page")) + object.__setattr__( + self, + "_compile_page_hooks", + self._resolve_hooks("compile_page"), + ) + enter_buckets: dict[HookOrder, list[EnterHookBinder]] = { + order: [] for order in HookOrder + } + leave_buckets: dict[HookOrder, list[LeaveHookBinder]] = { + order: [] for order in HookOrder + } + component_hooks_can_replace = False + + for plugin in self.plugins: + plugin_type = type(plugin) + if ( + hook_impl := self._get_hook_impl(plugin, "enter_component") + ) is not None: + enter_buckets[plugin_type._compiler_enter_component_order].append( + self._get_enter_hook_binder(plugin, hook_impl) + ) + component_hooks_can_replace = component_hooks_can_replace or bool( + getattr(plugin_type, "_compiler_can_replace_enter_component", True) + ) + + if ( + hook_impl := self._get_hook_impl(plugin, "leave_component") + ) is not None: + leave_buckets[plugin_type._compiler_leave_component_order].append( + self._get_leave_hook_binder(plugin, hook_impl) + ) + component_hooks_can_replace = component_hooks_can_replace or bool( + getattr(plugin_type, "_compiler_can_replace_leave_component", True) + ) + + object.__setattr__( + self, + "_enter_component_hook_binders", + tuple(binder for order in HookOrder for binder in enter_buckets[order]), + ) + object.__setattr__( + self, + "_leave_component_hook_binders", + tuple( + binder + for order in HookOrder + for binder in reversed(leave_buckets[order]) + ), + ) + object.__setattr__( + self, + "_component_hooks_can_replace", + component_hooks_can_replace, + ) + + @staticmethod + def _get_hook_impl( + plugin: Plugin, + hook_name: str, + ) -> Callable[..., Any] | None: + """Return the concrete hook implementation for a plugin, if any. + + Args: + plugin: The plugin to inspect. + hook_name: The hook attribute name. + + Returns: + The bound hook implementation, or ``None`` when the hook is inherited + unchanged from the default base implementation. + """ + plugin_impl = inspect.getattr_static(type(plugin), hook_name, None) + if plugin_impl is None: + return None + + if plugin_impl is inspect.getattr_static(Plugin, hook_name, None): + return None + + return cast(Callable[..., Any], getattr(plugin, hook_name, None)) + + def _resolve_hooks(self, hook_name: str) -> tuple[Callable[..., Any], ...]: + """Resolve concrete hook implementations for the plugin chain. + + Args: + hook_name: The hook attribute name. + + Returns: + The ordered concrete hook implementations for the hook. + """ + return tuple( + hook_impl + for plugin in self.plugins + if (hook_impl := self._get_hook_impl(plugin, hook_name)) is not None + ) + + @staticmethod + def _get_enter_hook_binder( + plugin: Plugin, + hook_impl: Callable[..., Any], + ) -> EnterHookBinder: + """Return a binder that produces a compiled enter-component hook.""" + if ( + binder := getattr(plugin, "_compiler_bind_enter_component", None) + ) is not None: + return cast(EnterHookBinder, binder) + + def bind( + page_context: PageContext, compile_context: CompileContext + ) -> CompiledEnterHook: + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> ComponentReplacement: + return cast( + ComponentReplacement, + hook_impl( + comp, + page_context=page_context, + compile_context=compile_context, + in_prop_tree=in_prop_tree, + ), + ) + + return enter_component + + return bind + + @staticmethod + def _get_leave_hook_binder( + plugin: Plugin, + hook_impl: Callable[..., Any], + ) -> LeaveHookBinder: + """Return a binder that produces a compiled leave-component hook.""" + if ( + binder := getattr(plugin, "_compiler_bind_leave_component", None) + ) is not None: + return cast(LeaveHookBinder, binder) + + def bind( + page_context: PageContext, compile_context: CompileContext + ) -> CompiledLeaveHook: + def leave_component( + comp: BaseComponent, + children: tuple[BaseComponent, ...], + in_prop_tree: bool, + ) -> ComponentReplacement: + return cast( + ComponentReplacement, + hook_impl( + comp, + children, + page_context=page_context, + compile_context=compile_context, + in_prop_tree=in_prop_tree, + ), + ) + + return leave_component + + return bind + + def eval_page( + self, + page_fn: PageComponent, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext | None: + """Return the first page context produced by the plugin chain.""" + for hook_impl in self._eval_page_hooks: + result = hook_impl(page_fn, page=page, **kwargs) + if result is not None: + return cast(PageContext, result) + return None + + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Run all ``compile_page`` hooks in plugin order.""" + for hook_impl in self._compile_page_hooks: + hook_impl(page_ctx, **kwargs) + + def compile_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree once while dispatching cached enter/leave hooks. + + Returns: + The compiled component root for this subtree. + """ + enter_hooks = tuple( + hook_binder(page_context, compile_context) + for hook_binder in self._enter_component_hook_binders + ) + + if not self._component_hooks_can_replace: + leave_hooks = tuple( + hook_binder(page_context, compile_context) + for hook_binder in self._leave_component_hook_binders + ) + + if len(enter_hooks) == 1 and not leave_hooks: + return self._compile_component_single_enter_fast_path( + comp, + enter_hook=enter_hooks[0], + page_context=page_context, + in_prop_tree=in_prop_tree, + ) + + return self._compile_component_without_replacements( + comp, + enter_hooks=enter_hooks, + leave_hooks=leave_hooks, + page_context=page_context, + in_prop_tree=in_prop_tree, + ) + + return self._compile_component_with_replacements( + comp, + enter_hooks=enter_hooks, + leave_hooks=tuple( + hook_binder(page_context, compile_context) + for hook_binder in self._leave_component_hook_binders + ), + page_context=page_context, + in_prop_tree=in_prop_tree, + ) + + def _compile_component_without_replacements( + self, + comp: BaseComponent, + /, + *, + enter_hooks: tuple[CompiledEnterHook, ...], + leave_hooks: tuple[CompiledLeaveHook, ...], + page_context: PageContext, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree when hook plans only observe state. + + Returns: + The compiled component root for this subtree. + """ + + def visit( + current_comp: BaseComponent, + current_in_prop_tree: bool, + ) -> BaseComponent: + for hook_impl in enter_hooks: + hook_impl( + current_comp, + current_in_prop_tree, + ) + + updated_children: list[BaseComponent] | None = None + children = current_comp.children + for index, child in enumerate(children): + compiled_child = visit( + child, + current_in_prop_tree, + ) + if updated_children is None: + if compiled_child is child: + continue + updated_children = list(children[:index]) + updated_children.append(compiled_child) + if updated_children is not None: + current_comp = page_context.own(current_comp) + current_comp.children = updated_children + + if isinstance(current_comp, Component): + for prop_component in current_comp._get_components_in_props(): + visit( + prop_component, + True, + ) + + if leave_hooks: + compiled_children = tuple(current_comp.children) + for hook_impl in leave_hooks: + hook_impl( + current_comp, + compiled_children, + current_in_prop_tree, + ) + + return current_comp + + return visit( + comp, + in_prop_tree, + ) + + def _compile_component_single_enter_fast_path( + self, + comp: BaseComponent, + /, + *, + enter_hook: CompiledEnterHook, + page_context: PageContext, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree for the common one-enter-hook fast path. + + Returns: + The compiled component root for this subtree. + """ + + def visit( + current_comp: BaseComponent, + current_in_prop_tree: bool, + ) -> BaseComponent: + enter_hook( + current_comp, + current_in_prop_tree, + ) + + updated_children: list[BaseComponent] | None = None + children = current_comp.children + for index, child in enumerate(children): + compiled_child = visit( + child, + current_in_prop_tree, + ) + if updated_children is None: + if compiled_child is child: + continue + updated_children = list(children[:index]) + updated_children.append(compiled_child) + if updated_children is not None: + current_comp = page_context.own(current_comp) + current_comp.children = updated_children + + if isinstance(current_comp, Component): + for prop_component in current_comp._get_components_in_props(): + visit( + prop_component, + True, + ) + + return current_comp + + return visit( + comp, + in_prop_tree, + ) + + def _compile_component_with_replacements( + self, + comp: BaseComponent, + /, + *, + enter_hooks: tuple[CompiledEnterHook, ...], + leave_hooks: tuple[CompiledLeaveHook, ...], + page_context: PageContext, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree while honoring hook replacements. + + Returns: + The compiled component root for this subtree. + """ + apply_replacement = self._apply_replacement + + def visit_children( + children: Sequence[BaseComponent], + current_in_prop_tree: bool, + ) -> tuple[BaseComponent, ...]: + if not children: + return () + + updated_children: list[BaseComponent] | None = None + for index, child in enumerate(children): + compiled_child = visit( + child, + current_in_prop_tree, + ) + if updated_children is None: + if compiled_child is child: + continue + updated_children = list(children[:index]) + updated_children.append(compiled_child) + if updated_children is None: + return children if isinstance(children, tuple) else tuple(children) + return tuple(updated_children) + + def visit( + current_comp: BaseComponent, + current_in_prop_tree: bool, + ) -> BaseComponent: + compiled_component = current_comp + structural_children: tuple[BaseComponent, ...] | None = None + + for hook_impl in enter_hooks: + compiled_component, structural_children = apply_replacement( + compiled_component, + structural_children, + hook_impl( + compiled_component, + current_in_prop_tree, + ), + ) + + if structural_children is None: + structural_children = tuple(compiled_component.children) + compiled_children = visit_children( + structural_children, + current_in_prop_tree, + ) + if isinstance(compiled_component, Component): + for prop_component in compiled_component._get_components_in_props(): + visit( + prop_component, + True, + ) + + for hook_impl in leave_hooks: + compiled_component, replacement_children = apply_replacement( + compiled_component, + compiled_children, + hook_impl( + compiled_component, + compiled_children, + current_in_prop_tree, + ), + ) + if replacement_children is not compiled_children: + assert replacement_children is not None + # Re-walking fires enter/leave again on any child objects + # carried over from the original children tuple. Observing + # collectors dedupe by dict key, so this is idempotent for + # today's plugins; stateful side effects on the page + # context would be double-applied. + compiled_children = visit_children( + replacement_children, + current_in_prop_tree, + ) + + current = compiled_component.children + if len(compiled_children) != len(current) or any( + a is not b for a, b in zip(compiled_children, current, strict=True) + ): + compiled_component = page_context.own(compiled_component) + compiled_component.children = list(compiled_children) + return compiled_component + + return visit( + comp, + in_prop_tree, + ) + + @staticmethod + def _apply_replacement( + comp: BaseComponent, + children: tuple[BaseComponent, ...] | None, + replacement: ComponentReplacement, + ) -> tuple[BaseComponent, tuple[BaseComponent, ...] | None]: + """Apply a plugin replacement to the current component state. + + Args: + comp: The current component. + children: The current structural children. + replacement: The plugin-supplied replacement. + + Returns: + The updated component and structural children pair. + """ + if replacement is None: + return comp, children + if isinstance(replacement, tuple): + return replacement + return replacement, children + + +@dataclasses.dataclass(kw_only=True) +class BaseContext: + """Context manager that exposes itself through a class-local context var.""" + + __context_var__: ClassVar[ContextVar[Self | None]] + + _attached_context_token: Token[Self | None] | None = dataclasses.field( + default=None, + init=False, + repr=False, + ) + + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: + """Initialize a dedicated context variable for each subclass.""" + super().__init_subclass__(**kwargs) + cls.__context_var__ = ContextVar(cls.__name__, default=None) + + @classmethod + def get(cls) -> Self: + """Return the active context instance for the current task. + + Returns: + The active context instance for the current task. + """ + context = cls.__context_var__.get() + if context is None: + msg = f"No active {cls.__name__} is attached to the current context." + raise RuntimeError(msg) + return context + + def __enter__(self) -> Self: + """Attach this context to the current task. + + Returns: + The attached context instance. + """ + if self._attached_context_token is not None: + msg = "Context is already attached and cannot be entered twice." + raise RuntimeError(msg) + self._attached_context_token = type(self).__context_var__.set(self) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task.""" + del exc_type, exc_val, exc_tb + if self._attached_context_token is None: + return + try: + type(self).__context_var__.reset(self._attached_context_token) + finally: + self._attached_context_token = None + + async def __aenter__(self) -> Self: + """Attach this context to the current task asynchronously. + + Returns: + The attached context instance. + """ + return self.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task asynchronously.""" + self.__exit__(exc_type, exc_val, exc_tb) + + def ensure_context_attached(self) -> None: + """Ensure this instance is the active context for the current task.""" + try: + current = type(self).get() + except RuntimeError as err: + msg = ( + f"{type(self).__name__} must be entered with 'with' or 'async with' " + "before calling this method." + ) + raise RuntimeError(msg) from err + if current is not self: + msg = f"{type(self).__name__} is not attached to the current task context." + raise RuntimeError(msg) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class PageContext(BaseContext): + """Mutable compilation state for a single page.""" + + name: str + route: str + root_component: BaseComponent + imports: list[ParsedImportDict] = dataclasses.field(default_factory=list) + module_code: dict[str, None] = dataclasses.field(default_factory=dict) + hooks: dict[str, VarData | None] = dataclasses.field(default_factory=dict) + dynamic_imports: set[str] = dataclasses.field(default_factory=set) + refs: dict[str, None] = dataclasses.field(default_factory=dict) + app_wrap_components: dict[tuple[int, str], Component] = dataclasses.field( + default_factory=dict + ) + frontend_imports: ParsedImportDict = dataclasses.field(default_factory=dict) + output_path: str | None = None + output_code: str | None = None + # Stack of ``id(component)`` for components whose subtree is + # memoize-suppressed. Populated by ``MemoizeStatefulPlugin`` when it + # encounters a ``MemoizationLeaf``-style snapshot boundary and popped on + # the matching ``leave_component``. Non-empty iff we are inside such a + # subtree. + memoize_suppressor_stack: list[int] = dataclasses.field(default_factory=list) + # Maps both the user-owned original's ``id()`` and the clone's ``id()`` to + # the page-local clone. Lets the walker and plugins rebind children, style, + # or event_triggers on a page-local copy without mutating a user-owned + # instance that may be referenced from another route. + _owned: dict[int, BaseComponent] = dataclasses.field(default_factory=dict) + # Strong references to originals keyed by ``id()`` above. Without these, + # an original that is only reachable through ``_owned``'s int key can be + # garbage collected, and Python may recycle its ``id()`` for a fresh + # component, causing ``own()`` to hand back the wrong clone. + _owned_refs: list[BaseComponent] = dataclasses.field(default_factory=list) + + def own(self, comp: _BaseComponentT) -> _BaseComponentT: + """Return a page-local copy of ``comp``, cloning on first encounter. + + Repeated calls with the same original return the same clone, so + mutations from several plugins accumulate on one instance. + + Args: + comp: The component the caller is about to mutate. + + Returns: + A component the caller may freely mutate without touching any + user-owned instance. + """ + existing = self._owned.get(id(comp)) + if existing is not None: + return cast("_BaseComponentT", existing) + new = copy.copy(comp) + self._owned[id(comp)] = new + self._owned[id(new)] = new + self._owned_refs.append(comp) + return new + + def merged_imports(self, *, collapse: bool = False) -> ParsedImportDict: + """Return the imports accumulated for this page. + + Args: + collapse: Whether to collapse duplicate imports. + + Returns: + The merged page imports. + """ + imports = merge_imports(*self.imports) if self.imports else {} + return collapse_imports(imports) if collapse else imports + + def custom_code_dict(self) -> dict[str, None]: + """Return custom-code snippets keyed like legacy collectors. + + Returns: + The page custom code keyed by snippet. + """ + return dict(self.module_code) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class CompileContext(BaseContext): + """Mutable compilation state for an entire compile run.""" + + app: App | None = None + pages: Sequence[PageDefinition] + hooks: CompilerHooks = dataclasses.field(default_factory=CompilerHooks) + compiled_pages: dict[str, PageContext] = dataclasses.field(default_factory=dict) + all_imports: ParsedImportDict = dataclasses.field(default_factory=dict) + app_wrap_components: dict[tuple[int, str], Component] = dataclasses.field( + default_factory=dict + ) + stateful_routes: dict[str, None] = dataclasses.field(default_factory=dict) + # Auto-memoize wrapper tags seen during the tree walk (populated by + # ``MemoizeStatefulPlugin``). + memoize_wrappers: dict[str, None] = dataclasses.field(default_factory=dict) + # Compiler-generated experimental memo definitions for auto-memoized + # stateful wrappers. Stored as ``Any`` to keep ``reflex_base`` decoupled + # from ``reflex.experimental.memo``. + auto_memo_components: dict[str, Any] = dataclasses.field(default_factory=dict) + + def compile( + self, + *, + evaluate_progress: Callable[[], None] | None = None, + render_progress: Callable[[], None] | None = None, + **kwargs: Any, + ) -> dict[str, PageContext]: + """Compile all configured pages through the plugin pipeline. + + Args: + evaluate_progress: Callback invoked after each page evaluation. + render_progress: Callback invoked after each page render. + kwargs: Additional compiler-specific context. + + Returns: + The compiled page contexts keyed by route. + """ + from reflex.compiler import compiler + from reflex.state import all_base_state_classes + + self.ensure_context_attached() + self.compiled_pages.clear() + self.all_imports.clear() + self.app_wrap_components.clear() + self.stateful_routes.clear() + self.memoize_wrappers.clear() + self.auto_memo_components.clear() + + for page in self.pages: + page_fn = page.component + n_states_before = len(all_base_state_classes) + page_ctx = self.hooks.eval_page( + page_fn, + page=page, + compile_context=self, + **kwargs, + ) + if page_ctx is None: + page_name = getattr(page_fn, "__name__", repr(page_fn)) + msg = ( + f"No compiler plugin was able to evaluate page {page.route!r} " + f"({page_name})." + ) + raise RuntimeError(msg) + if page_ctx.route in self.compiled_pages: + msg = f"Duplicate compiled page route {page_ctx.route!r}." + raise RuntimeError(msg) + + if len(all_base_state_classes) > n_states_before: + self.stateful_routes[page.route] = None + + self.compiled_pages[page_ctx.route] = page_ctx + + if evaluate_progress is not None: + evaluate_progress() + + for page, page_ctx in zip( + self.pages, + self.compiled_pages.values(), + strict=True, + ): + with page_ctx: + page_ctx.root_component = self.hooks.compile_component( + page_ctx.root_component, + page_context=page_ctx, + compile_context=self, + ) + self.hooks.compile_page( + page_ctx, + page=page, + compile_context=self, + **kwargs, + ) + + page_ctx.frontend_imports = page_ctx.merged_imports(collapse=True) + self.all_imports = merge_imports( + self.all_imports, page_ctx.frontend_imports + ) + self.app_wrap_components.update(page_ctx.app_wrap_components) + page_ctx.output_path, page_ctx.output_code = ( + compiler.compile_page_from_context(page_ctx) + ) + + if render_progress is not None: + render_progress() + + return self.compiled_pages + + +__all__ = [ + "BaseContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "PageContext", + "PageDefinition", +] diff --git a/packages/reflex-base/src/reflex_base/plugins/tailwind_v3.py b/packages/reflex-base/src/reflex_base/plugins/tailwind_v3.py index d67264fc0e3..66f575db5c2 100644 --- a/packages/reflex-base/src/reflex_base/plugins/tailwind_v3.py +++ b/packages/reflex-base/src/reflex_base/plugins/tailwind_v3.py @@ -29,7 +29,7 @@ class Constants(SimpleNamespace): ROOT_STYLE_CONTENT = """ @import "tailwindcss/base"; -@import url('{radix_url}'); +{radix_import} @tailwind components; @tailwind utilities; @@ -54,9 +54,12 @@ def compile_config(config: TailwindConfig): ) -def compile_root_style(): +def compile_root_style(include_radix_themes: bool = True): """Compile the Tailwind root style. + Args: + include_radix_themes: Whether to include the Radix stylesheet import. + Returns: The compiled Tailwind root style. """ @@ -65,7 +68,9 @@ def compile_root_style(): return str( Path(Dirs.STYLES) / Constants.ROOT_STYLE_PATH ), Constants.ROOT_STYLE_CONTENT.format( - radix_url=RADIX_THEMES_STYLESHEET, + radix_import=( + f"@import url('{RADIX_THEMES_STYLESHEET}');" if include_radix_themes else "" + ), ) @@ -112,11 +117,14 @@ def add_tailwind_to_postcss_config(postcss_file_content: str) -> str: return "\n".join(postcss_file_lines) -def add_tailwind_to_css_file(css_file_content: str) -> str: +def add_tailwind_to_css_file( + css_file_content: str, *, include_radix_themes: bool = True +) -> str: """Add tailwind to the css file. Args: css_file_content: The content of the css file. + include_radix_themes: Whether the root stylesheet already imports Radix. Returns: The modified css file content. @@ -125,16 +133,23 @@ def add_tailwind_to_css_file(css_file_content: str) -> str: if Constants.TAILWIND_CSS.splitlines()[0] in css_file_content: return css_file_content - if RADIX_THEMES_STYLESHEET not in css_file_content: - print( # noqa: T201 - f"Could not find line with '{RADIX_THEMES_STYLESHEET}' in {Dirs.STYLES}. " - "Please make sure the file exists and is valid." + if include_radix_themes and RADIX_THEMES_STYLESHEET in css_file_content: + return css_file_content.replace( + f"@import url('{RADIX_THEMES_STYLESHEET}');", + Constants.TAILWIND_CSS, ) - return css_file_content - return css_file_content.replace( - f"@import url('{RADIX_THEMES_STYLESHEET}');", - Constants.TAILWIND_CSS, + + lines = css_file_content.splitlines() + insert_at = next( + ( + index + 1 + for index, line in enumerate(lines) + if "__reflex_style_reset.css" in line + ), + 1, ) + lines.insert(insert_at, Constants.TAILWIND_CSS) + return "\n".join(lines) @dataclasses.dataclass @@ -162,9 +177,14 @@ def pre_compile(self, **context): context: The context for the plugin. """ context["add_save_task"](compile_config, self.get_unversioned_config()) - context["add_save_task"](compile_root_style) + include_radix_themes = context["radix_themes_plugin"].enabled + + context["add_save_task"](compile_root_style, include_radix_themes) context["add_modify_task"](Dirs.POSTCSS_JS, add_tailwind_to_postcss_config) context["add_modify_task"]( str(Path(Dirs.STYLES) / (PageNames.STYLESHEET_ROOT + Ext.CSS)), - add_tailwind_to_css_file, + lambda content: add_tailwind_to_css_file( + content, + include_radix_themes=include_radix_themes, + ), ) diff --git a/packages/reflex-base/src/reflex_base/plugins/tailwind_v4.py b/packages/reflex-base/src/reflex_base/plugins/tailwind_v4.py index a3f1c24eb5b..7c11dc84c2a 100644 --- a/packages/reflex-base/src/reflex_base/plugins/tailwind_v4.py +++ b/packages/reflex-base/src/reflex_base/plugins/tailwind_v4.py @@ -29,8 +29,7 @@ class Constants(SimpleNamespace): ROOT_STYLE_CONTENT = """@layer theme, base, components, utilities; @import "tailwindcss/theme.css" layer(theme); @import "tailwindcss/preflight.css" layer(base); -@import "{radix_url}" layer(components); -@import "tailwindcss/utilities.css" layer(utilities); +{radix_import}@import "tailwindcss/utilities.css" layer(utilities); @config "../tailwind.config.js"; """ @@ -53,9 +52,12 @@ def compile_config(config: TailwindConfig): ) -def compile_root_style(): +def compile_root_style(include_radix_themes: bool = True): """Compile the Tailwind root style. + Args: + include_radix_themes: Whether to include the Radix stylesheet import. + Returns: The compiled Tailwind root style. """ @@ -64,7 +66,11 @@ def compile_root_style(): return str( Path(Dirs.STYLES) / Constants.ROOT_STYLE_PATH ), Constants.ROOT_STYLE_CONTENT.format( - radix_url=RADIX_THEMES_STYLESHEET, + radix_import=( + f'@import "{RADIX_THEMES_STYLESHEET}" layer(components);\n' + if include_radix_themes + else "" + ), ) @@ -115,11 +121,14 @@ def add_tailwind_to_postcss_config(postcss_file_content: str) -> str: return "\n".join(postcss_file_lines) -def add_tailwind_to_css_file(css_file_content: str) -> str: +def add_tailwind_to_css_file( + css_file_content: str, *, include_radix_themes: bool = True +) -> str: """Add tailwind to the css file. Args: css_file_content: The content of the css file. + include_radix_themes: Whether the root stylesheet already imports Radix. Returns: The modified css file content. @@ -128,16 +137,23 @@ def add_tailwind_to_css_file(css_file_content: str) -> str: if Constants.TAILWIND_CSS.splitlines()[0] in css_file_content: return css_file_content - if RADIX_THEMES_STYLESHEET not in css_file_content: - print( # noqa: T201 - f"Could not find line with '{RADIX_THEMES_STYLESHEET}' in {Dirs.STYLES}. " - "Please make sure the file exists and is valid." + if include_radix_themes and RADIX_THEMES_STYLESHEET in css_file_content: + return css_file_content.replace( + f"@import url('{RADIX_THEMES_STYLESHEET}');", + Constants.TAILWIND_CSS, ) - return css_file_content - return css_file_content.replace( - f"@import url('{RADIX_THEMES_STYLESHEET}');", - Constants.TAILWIND_CSS, + + lines = css_file_content.splitlines() + insert_at = next( + ( + index + 1 + for index, line in enumerate(lines) + if "__reflex_style_reset.css" in line + ), + 1, ) + lines.insert(insert_at, Constants.TAILWIND_CSS) + return "\n".join(lines) @dataclasses.dataclass @@ -166,9 +182,14 @@ def pre_compile(self, **context): context: The context for the plugin. """ context["add_save_task"](compile_config, self.get_unversioned_config()) - context["add_save_task"](compile_root_style) + include_radix_themes = context["radix_themes_plugin"].enabled + + context["add_save_task"](compile_root_style, include_radix_themes) context["add_modify_task"](Dirs.POSTCSS_JS, add_tailwind_to_postcss_config) context["add_modify_task"]( str(Path(Dirs.STYLES) / (PageNames.STYLESHEET_ROOT + Ext.CSS)), - add_tailwind_to_css_file, + lambda content: add_tailwind_to_css_file( + content, + include_radix_themes=include_radix_themes, + ), ) diff --git a/packages/reflex-base/src/reflex_base/registry.py b/packages/reflex-base/src/reflex_base/registry.py index 8caa1d2b2c3..71b4d723e5e 100644 --- a/packages/reflex-base/src/reflex_base/registry.py +++ b/packages/reflex-base/src/reflex_base/registry.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: from reflex.state import BaseState - from reflex_base.components.component import StatefulComponent from reflex_base.event import EventHandler @@ -40,10 +39,6 @@ class RegistrationContext(BaseContext): default_factory=dict, repr=False, ) - tag_to_stateful_component: dict[str, StatefulComponent] = dataclasses.field( - default_factory=dict, - repr=False, - ) @classmethod def ensure_context(cls) -> Self: diff --git a/packages/reflex-base/src/reflex_base/utils/console.py b/packages/reflex-base/src/reflex_base/utils/console.py index b7a9539f77b..962008e42e0 100644 --- a/packages/reflex-base/src/reflex_base/utils/console.py +++ b/packages/reflex-base/src/reflex_base/utils/console.py @@ -479,6 +479,18 @@ def advance(self, task: TaskID, advance: int = 1): self.progress += advance _console.print(f"Progress: {self.progress}/{self.total}") + def update(self, task: TaskID, total: int | None = None): + """Update properties of a task. + + Args: + task: The task ID. + total: New total for the task. + """ + if total is not None and task in self.tasks: + previous_total = self.tasks[task]["total"] + self.tasks[task]["total"] = total + self.total += total - previous_total + def start(self): """Start the progress bar.""" diff --git a/packages/reflex-base/src/reflex_base/utils/streaming_response.py b/packages/reflex-base/src/reflex_base/utils/streaming_response.py index d9907379cef..66c6ab7fc62 100644 --- a/packages/reflex-base/src/reflex_base/utils/streaming_response.py +++ b/packages/reflex-base/src/reflex_base/utils/streaming_response.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import builtins import contextlib import sys @@ -60,14 +59,16 @@ def _collapse_excgroups() -> Generator[None, None, None]: class DisconnectAwareStreamingResponse(StreamingResponse): - """Streaming response that cancels its body task on disconnect.""" + """Streaming response with a guaranteed finish callback.""" _on_finish: Callable[[], Awaitable[None]] + _on_disconnect: Callable[[], None] | None def __init__( self, *args: Any, on_finish: Callable[[], Awaitable[None]], + on_disconnect: Callable[[], None] | None = None, **kwargs: Any, ) -> None: """Initialize the response. @@ -75,17 +76,17 @@ def __init__( Args: args: Positional args forwarded to ``StreamingResponse``. on_finish: Cleanup callback to run exactly once when the response ends. + on_disconnect: Sync callback invoked when the client disconnects. kwargs: Keyword args forwarded to ``StreamingResponse``. """ super().__init__(*args, **kwargs) self._on_finish = on_finish + self._on_disconnect = on_disconnect - async def _watch_disconnect(self, receive: Receive) -> None: - """Wait for the client connection to close.""" - while True: - message = await receive() - if message["type"] == "http.disconnect": - return + def _notify_disconnect(self) -> None: + """Invoke the on_disconnect callback if one was provided.""" + if self._on_disconnect is not None: + self._on_disconnect() async def _close_body_iterator(self) -> None: """Close the body iterator if it supports ``aclose``.""" @@ -94,7 +95,7 @@ async def _close_body_iterator(self) -> None: await aclose() async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Serve the response and cancel the body task on disconnect.""" + """Serve the response and always run the finish callback.""" spec_version = _parse_asgi_spec_version(scope) try: @@ -107,47 +108,24 @@ async def wrap(func: Callable[[], Awaitable[None]]) -> None: task_group.cancel_scope.cancel() task_group.start_soon(wrap, partial(self.stream_response, send)) - await wrap(partial(self.listen_for_disconnect, receive)) - else: - # Verified against Starlette 0.52.1: the ASGI >= 2.4 path in - # StreamingResponse.__call__ delegates straight to - # stream_response(send) and does not read from receive(). - # Keep calling stream_response(send) directly here so the - # disconnect watcher remains the only receive() consumer; if - # Starlette changes that contract, re-check this logic. - stream_task = asyncio.create_task(self.stream_response(send)) - disconnect_task = asyncio.create_task(self._watch_disconnect(receive)) - should_close_body_iterator = False + if self._on_disconnect is not None: + + async def _disconnect_then_notify() -> None: + await self.listen_for_disconnect(receive) + self._notify_disconnect() + + await wrap(_disconnect_then_notify) + else: + await wrap(partial(self.listen_for_disconnect, receive)) + else: try: - done, _ = await asyncio.wait( - {stream_task, disconnect_task}, - return_when=asyncio.FIRST_COMPLETED, - ) - if disconnect_task in done and not stream_task.done(): - should_close_body_iterator = True - stream_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await stream_task - else: - try: - await stream_task - except OSError as err: - should_close_body_iterator = True - raise ClientDisconnect from err - finally: - if not disconnect_task.done(): - disconnect_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await disconnect_task - if not stream_task.done(): - should_close_body_iterator = True - stream_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await stream_task - if should_close_body_iterator: - await self._close_body_iterator() + await self.stream_response(send) + except OSError as err: + self._notify_disconnect() + raise ClientDisconnect from err finally: + await self._close_body_iterator() await self._on_finish() if self.background is not None: diff --git a/packages/reflex-components-core/src/reflex_components_core/base/link.py b/packages/reflex-components-core/src/reflex_components_core/base/link.py index a3a7b89dfbf..78bd301bcb2 100644 --- a/packages/reflex-components-core/src/reflex_components_core/base/link.py +++ b/packages/reflex-components-core/src/reflex_components_core/base/link.py @@ -3,10 +3,10 @@ from reflex_base.components.component import field from reflex_base.vars.base import Var -from reflex_components_core.el.elements.base import BaseHTML +from reflex_components_core.el.elements.base import RawTextBaseHTML, VoidBaseHTML -class RawLink(BaseHTML): +class RawLink(VoidBaseHTML): """A component that displays the title of the current page.""" tag = "link" @@ -16,7 +16,7 @@ class RawLink(BaseHTML): rel: Var[str] = field(doc="The type of link.") -class ScriptTag(BaseHTML): +class ScriptTag(RawTextBaseHTML): """A script tag with the specified type and source.""" tag = "script" diff --git a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index 0674e071fbd..680fd7c613f 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -495,12 +495,22 @@ def _create_upload_event() -> Event: msg = "Upload event was not created." raise RuntimeError(msg) + disconnect_seen = False + + def _mark_disconnected() -> None: + nonlocal disconnect_seen + disconnect_seen = True + async def _ndjson_updates(): """Process the upload event, generating ndjson updates. Yields: Each state update as newline-delimited JSON. """ + # Let the disconnect watcher run before we enqueue the upload handler. + await asyncio.sleep(0) + if disconnect_seen: + return # Enqueue the task on the main event loop, but emit deltas to the local queue. async for delta in app.event_processor.enqueue_stream_delta(token, event): yield json_dumps(StateUpdate(delta=delta)) + "\n" @@ -509,6 +519,7 @@ async def _ndjson_updates(): _ndjson_updates(), media_type="application/x-ndjson", on_finish=_close_form_data, + on_disconnect=_mark_disconnected, ) diff --git a/packages/reflex-components-core/src/reflex_components_core/core/cond.py b/packages/reflex-components-core/src/reflex_components_core/core/cond.py index 98ba56d170d..a35443cb17c 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/cond.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/cond.py @@ -5,6 +5,7 @@ from typing import Any, TypeVar, overload from reflex_base.components.component import BaseComponent, Component, field +from reflex_base.components.memoize_helpers import passthrough_children_var from reflex_base.components.tags import CondTag, Tag from reflex_base.constants import Dirs from reflex_base.style import LIGHT_COLOR_MODE, resolved_color_mode @@ -14,6 +15,7 @@ from reflex_base.vars.base import LiteralVar, Var from reflex_base.vars.number import ternary_operation +from reflex_components_core.base.bare import Bare from reflex_components_core.base.fragment import Fragment _IS_TRUE_IMPORT: ImportDict = { @@ -26,6 +28,28 @@ class Cond(Component): cond: Var[Any] = field(doc="The cond to determine which component to render.") + def _get_cond_children(self) -> tuple[BaseComponent, BaseComponent]: + """Get true and false branch components with safe defaults. + + When rendering inside an auto-memo passthrough body, ``self.children`` + is collapsed to a single ``Bare`` holding the placeholder ``children`` + Var. The branches are reconstructed as indexed accesses + (``children[0]`` / ``children[1]``) so the page-side renders the real + branch JSX and the memo body just selects which one to mount. + + Returns: + A tuple containing true and false branch components. + """ + children_var = passthrough_children_var(self.children) + if children_var is not None: + return ( + Bare.create(children_var[0]), + Bare.create(children_var[1]), + ) + true_child = self.children[0] if self.children else Fragment.create() + false_child = self.children[1] if len(self.children) > 1 else Fragment.create() + return true_child, false_child + @classmethod def create( cls, @@ -60,10 +84,11 @@ def create( ) def _render(self) -> Tag: + true_child, false_child = self._get_cond_children() return CondTag( cond_state=str(self.cond), - true_value=self.children[0].render(), - false_value=self.children[1].render(), + true_value=true_child.render(), + false_value=false_child.render(), ) def render(self) -> dict: @@ -72,10 +97,11 @@ def render(self) -> dict: Returns: The dictionary for template of component. """ + true_child, false_child = self._get_cond_children() return { "cond_state": str(self.cond), - "true_value": self.children[0].render(), - "false_value": self.children[1].render(), + "true_value": true_child.render(), + "false_value": false_child.render(), } def add_imports(self) -> ImportDict: diff --git a/packages/reflex-components-core/src/reflex_components_core/core/debounce.py b/packages/reflex-components-core/src/reflex_components_core/core/debounce.py index 500cd1f96a6..b627e87eea5 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/debounce.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/debounce.py @@ -7,6 +7,7 @@ from reflex_base.components.component import Component, field from reflex_base.constants import EventTriggers from reflex_base.event import EventHandler, no_args_event_spec +from reflex_base.utils import format from reflex_base.vars import VarData from reflex_base.vars.base import Var @@ -106,7 +107,15 @@ def create(cls, *children: Component, **props: Any) -> Component: } props.setdefault("custom_attrs", {}).update(other_props, **child.custom_attrs) - # Carry base Component props. + # Carry base Component props. Drop any keys from child.style that + # collide with DebounceInput's own props — Reflex routes unknown child + # kwargs (e.g. ``debounce_timeout`` passed through ``rx.input``) into + # ``style``. + debounce_input_prop_names = { + format.to_camel_case(prop) for prop in cls.get_props() + } + for colliding_key in [k for k in child.style if k in debounce_input_prop_names]: + child.style.pop(colliding_key) props.setdefault("style", {}).update(child.style) if child.class_name is not None: props["class_name"] = f"{props.get('class_name', '')} {child.class_name}" diff --git a/packages/reflex-components-core/src/reflex_components_core/core/match.py b/packages/reflex-components-core/src/reflex_components_core/core/match.py index d062587b04c..9216cfb767b 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/match.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/match.py @@ -3,12 +3,8 @@ import textwrap from typing import Any, cast -from reflex_base.components.component import ( - BaseComponent, - Component, - MemoizationLeaf, - field, -) +from reflex_base.components.component import BaseComponent, Component, field +from reflex_base.components.memoize_helpers import passthrough_children_var from reflex_base.components.tags import Tag from reflex_base.components.tags.match_tag import MatchTag from reflex_base.style import Style @@ -19,9 +15,10 @@ from reflex_base.vars.base import LiteralVar, Var from reflex_components_core.base import Fragment +from reflex_components_core.base.bare import Bare -class Match(MemoizationLeaf): +class Match(Component): """Match cases based on a condition.""" cond: Var[Any] = field(doc="The condition to determine which case to match.") @@ -270,13 +267,39 @@ def _create_match_cond_var_or_component( ) def _render(self) -> Tag: + # Reconstruct match_cases and default from self.children, which may have + # been updated by the compiler walker to include memoized wrappers. + # self.children contains: [case_1_return, case_2_return, ..., default] + # self.match_cases contains the conditions as Vars. + num_cases = len(self.match_cases) + children_var = passthrough_children_var(self.children) + if children_var is not None: + # Auto-memo passthrough body: index into the placeholder array so + # branch JSX stays on the page side. + cases_returns = [Bare.create(children_var[i]) for i in range(num_cases)] + default_return = Bare.create(children_var[num_cases]) + else: + if len(self.children) != num_cases + 1: + msg = ( + f"Match children count mismatch: expected {num_cases + 1} " + f"(cases + default), got {len(self.children)}" + ) + raise ValueError(msg) + + cases_returns = self.children[:num_cases] + default_return = self.children[num_cases] + return MatchTag( cond=str(self.cond), match_cases=[ ([str(cond) for cond in conditions], return_value.render()) - for conditions, return_value in self.match_cases + for (conditions, _), return_value in zip( + self.match_cases, + cases_returns, + strict=True, + ) ], - default=self.default.render(), + default=default_return.render(), ) def render(self) -> dict: diff --git a/packages/reflex-components-core/src/reflex_components_core/core/upload.py b/packages/reflex-components-core/src/reflex_components_core/core/upload.py index 642ef851749..9bbb28de64d 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/upload.py @@ -10,9 +10,9 @@ Component, ComponentNamespace, MemoizationLeaf, - StatefulComponent, field, ) +from reflex_base.components.memoize_helpers import get_memoized_event_triggers from reflex_base.constants import Dirs from reflex_base.constants.compiler import Hooks, Imports from reflex_base.environment import environment @@ -358,16 +358,13 @@ def create(cls, *children, **props) -> Component: ), ) - event_triggers = StatefulComponent._get_memoized_event_triggers( + event_triggers = get_memoized_event_triggers( GhostUpload.create( on_drop=upload_props["on_drop"], on_drop_rejected=upload_props["on_drop_rejected"], ) ) - callback_hooks = [] - for trigger_name, (event_var, callback_str) in event_triggers.items(): - upload_props[trigger_name] = event_var - callback_hooks.append(callback_str) + upload_props.update(event_triggers) upload_props = { format.to_camel_case(key): value for key, value in upload_props.items() @@ -392,7 +389,6 @@ def create(cls, *children, **props) -> Component: use_dropzone_arguments._get_all_var_data(), VarData( hooks={ - **dict.fromkeys(callback_hooks, None), f"{left_side} = {right_side};": None, }, imports={ diff --git a/packages/reflex-components-core/src/reflex_components_core/core/window_events.py b/packages/reflex-components-core/src/reflex_components_core/core/window_events.py index debb4c3dc37..ee57bb770c0 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/window_events.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/window_events.py @@ -4,7 +4,7 @@ from typing import Any, cast -from reflex_base.components.component import StatefulComponent, field +from reflex_base.components.component import field from reflex_base.constants.compiler import Hooks from reflex_base.event import EventHandler, key_event, no_args_event_spec from reflex_base.vars.base import Var, VarData @@ -95,9 +95,15 @@ def create(cls, **props) -> WindowEventListener: Returns: The created component. """ + from reflex_base.components.memoize_helpers import get_memoized_event_triggers + real_component = cast("WindowEventListener", super().create(**props)) - hooks = StatefulComponent._fix_event_triggers(real_component) - real_component.hooks = hooks + memo_event_triggers = get_memoized_event_triggers(real_component) + if memo_event_triggers: + real_component.event_triggers = { + **real_component.event_triggers, + **memo_event_triggers, + } return real_component def _exclude_props(self) -> list[str]: diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/base.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/base.py index e8b922f7430..476e3768edb 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/base.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/base.py @@ -3,6 +3,7 @@ from typing import Literal from reflex_base.components.component import field +from reflex_base.constants.compiler import MemoizationMode from reflex_base.vars.base import Var from reflex_components_core.el.element import Element @@ -140,3 +141,29 @@ class BaseHTML(Element): ) title: Var[str] = field(doc="Defines a tooltip for the element.") + + +class RawTextBaseHTML(BaseHTML): + """Base class for HTML elements with raw-text (RCDATA) content models. + + Raw-text elements (``