diff --git a/README.md b/README.md index 700c82d..68ace52 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,8 @@ KernelScript addresses these problems through revolutionary language features: ✅ **Zero-boilerplate shared state** - Maps are automatically accessible across all programs as regular global variables in a programming language +✅ **Ergonomic map idioms** - Declaration-as-condition (`if (var s = m[k]) { s.field = ... }`) and compound assignment on map indices (`m[k].count += 1`) compile down to a single presence-checked lookup with in-place mutation, no manual write-back + ✅ **Builtin kfunc support** - Define full-privilege kernel functions that eBPF programs can call directly, automatically generating kernel modules and BTF registrations ✅ **Unified error handling** - C-style integer throw/catch works seamlessly in both eBPF and userspace contexts, unlike complex Result types @@ -213,16 +215,38 @@ fn handle_action(action: FilterAction) -> xdp_action { } } -// Map lookup and update patterns +// Map lookup and update patterns — declaration-as-condition binds +// `count` only inside the truthy branch; one map lookup, no extra +// presence-check variable. fn lookup_or_create(ip: IpAddress) -> Counter { - var count = connection_count[ip] - if (count != none) { + if (var count = connection_count[ip]) { return count // Entry exists } else { connection_count[ip] = 1 // Create new entry return 1 } } + +// Declaration-as-condition: bind only inside the truthy branch. +// For struct-valued maps, the bound name is the lookup pointer, so +// field access auto-derefs and the generated eBPF performs in-place +// mutation against the underlying entry — no write-back needed. +pin var ip_stats : hash(1024) + +@helper +fn record_packet(ip: IpAddress, size: PacketSize) { + if (var stats = ip_stats[ip]) { + stats.size = size + } else { + ip_stats[ip] = PacketInfo { src_ip: ip, dst_ip: 0, protocol: 0, size: size } + } +} + +// Compound assignment indexes into struct-valued maps directly: +@helper +fn bump_size(ip: IpAddress, delta: PacketSize) { + ip_stats[ip].size += delta // emits a presence-checked ptr->size += delta +} ``` ### Multi-Program Coordination diff --git a/SPEC.md b/SPEC.md index 8e1e2cf..04483a9 100644 --- a/SPEC.md +++ b/SPEC.md @@ -1476,8 +1476,7 @@ fn ddos_protection(ctx: *xdp_md) -> xdp_action { @tc("ingress") fn connection_tracker(ctx: *__sk_buff) -> i32 { - var tcp_info = extract_tcp_info(ctx) // Reuse same helper - if (tcp_info != null) { + if (var tcp_info = extract_tcp_info(ctx)) { // Reuse same helper track_connection(tcp_info.src_port, tcp_info.dst_port) } return 0 // TC_ACT_OK @@ -2357,9 +2356,9 @@ fn ebpf_pointer_usage(ctx: *xdp_md) -> xdp_action { } } - // Dynptr-backed pointers (transparent to user) - var log_buffer: *u8 = event_log.reserve(256) // Returns dynptr-backed pointer - if (log_buffer != null) { + // Dynptr-backed pointers (transparent to user) — `log_buffer` is the + // *u8 returned by reserve(), in scope only inside the truthy branch. + if (var log_buffer = event_log.reserve(256)) { // Regular pointer operations - compiler uses dynptr API internally log_buffer[0] = EVENT_TYPE_PACKET write_packet_summary(log_buffer + 1, packet_data, 255) @@ -2428,15 +2427,14 @@ var flow_map : hash(1024) @helper fn map_pointer_operations(flow_key: FlowKey) { - // Map lookup returns pointer to value - var flow_data = flow_map[flow_key] - - if (flow_data != none) { + // Declaration-as-condition: a single map lookup; `flow_data` is the + // returned pointer, in scope only inside the truthy branch. + if (var flow_data = flow_map[flow_key]) { // Direct modification through pointer flow_data->packet_count += 1 flow_data->byte_count += packet_size flow_data->last_seen = bpf_ktime_get_ns() - + // Compiler tracks map value lifetime // flow_data becomes invalid after certain map operations } @@ -2605,9 +2603,8 @@ fn egress_monitor(ctx: *__sk_buff) -> i32 { fn security_analyzer(ctx: LsmContext) -> i32 { var flow_key = extract_flow_key_from_socket(ctx)? - // Check global flow statistics - if (global_flows[flow_key] != null) { - var flow_stats = global_flows[flow_key] + // Check global flow statistics — single lookup via IfLet + if (var flow_stats = global_flows[flow_key]) { if (flow_stats.is_suspicious()) { security_events.submit(SecurityEvent { event_type: EVENT_TYPE_SUSPICIOUS_CONNECTION, @@ -2617,7 +2614,7 @@ fn security_analyzer(ctx: LsmContext) -> i32 { return -EPERM // Block connection } } - + return 0 // Allow connection } ``` @@ -2769,12 +2766,10 @@ var flow_stats : hash(1024) @helper fn update_flow_stats(flow_id: u32, packet_size: u32) { - var stats = flow_stats[flow_id] - if (stats != null) { - stats.packet_count += 1 - stats.total_bytes += packet_size - stats.avg_packet_size = stats.total_bytes / stats.packet_count - } + // Compound assignment on a struct-field of a map value emits a single + // presence-checked map lookup and mutates in place; see §6.2.5. + flow_stats[flow_id].packet_count += 1 + flow_stats[flow_id].total_bytes += packet_size } ``` @@ -2825,7 +2820,73 @@ fn main() -> i32 { } ``` -#### 6.2.5 Performance and Code Generation +#### 6.2.5 Compound Assignment with Map Indexing + +KernelScript extends compound assignment to map index expressions, so a +counter update against a map value can be written without an intermediate +variable or an explicit write-back. + +##### 6.2.5.1 Scalar map values + +When the map's value type is an integer, `m[k] op= rhs` reads the current +entry, applies `op`, and writes the result back. If the entry is absent +the read yields zero, so the operation creates the entry on first use. + +```kernelscript +var packet_counts : hash(1024) + +@xdp +fn rate_limiter(ctx: *xdp_md) -> xdp_action { + var src_ip = extract_src_ip(ctx) + packet_counts[src_ip] += 1 // read-modify-write; creates entry if absent + return XDP_PASS +} +``` + +The supported operators are `+=`, `-=`, `*=`, `/=`, `%=`. The map's value +type must be one of the integer primitives. + +##### 6.2.5.2 Struct-field map values + +When the map's value type is a struct, `m[k].field op= rhs` mutates a +single field of an existing entry in place. The compiler lowers the form +to a presence-checked pointer mutation: + +```kernelscript +struct PacketStats { + count: u64, + total_bytes: u64, +} + +var ip_stats : hash(1024) + +@xdp +fn observe(ctx: *xdp_md) -> xdp_action { + var ip = extract_src_ip(ctx) + var len = packet_len(ctx) + ip_stats[ip].count += 1 + ip_stats[ip].total_bytes += len + return XDP_PASS +} +``` + +Semantics: + +- **Map identifier required.** The left-hand side must be `IDENT[expr].field op= rhs`; + arbitrary LHS expressions are not allowed. +- **Value type must be a struct.** `field` is resolved against the map's value + struct definition; an unknown field is a compile-time error. +- **Field type drives `op`.** The named field must be one of the integer + primitives; the right-hand side must be assignment-compatible with the field type. +- **Presence check, no creation.** If the entry is absent the statement is a + no-op — unlike scalar `m[k] op= rhs`, the struct-field form does *not* + create a default entry. To handle the missing case, pair it with an + explicit `else` using the declaration-as-condition form (see §7.5.1). +- **Single map lookup.** Generated code performs one `bpf_map_lookup_elem`, + guards on the returned pointer, and writes through it + (`if (p) { p->field = p->field op rhs; }`); there is no separate write-back. + +#### 6.2.6 Performance and Code Generation Compound assignments generate efficient code in both contexts: @@ -2952,7 +3013,7 @@ fn packet_filter(ctx: *xdp_md) -> action: xdp_action { // Userspace functions with named returns fn lookup_counter(ip: u32) -> counter_ptr: *u64 { - if (counters[ip] == none) { + if (counters[ip] == null) { counters[ip] = 0 } counter_ptr = &counters[ip] @@ -3220,6 +3281,13 @@ fn main(args: Args) -> i32 { ### 7.5 Control Flow Statements #### 7.5.1 Conditional Statements + +KernelScript provides two `if` forms: a standard expression-condition +form and a *declaration-as-condition* form that combines a single-use +binding with a presence check. + +##### 7.5.1.1 Expression-condition form + ```kernelscript // Conditional statements if (condition) { @@ -3231,6 +3299,48 @@ if (condition) { } ``` +##### 7.5.1.2 Declaration-as-condition form (`if (var name = expr)`) + +```kernelscript +if (var name = expr) { + // then-branch: `name` is in scope and bound to `expr`'s value +} else { + // else-branch: `name` is *not* in scope +} +``` + +The branch is taken iff `expr` produces a *present* value: + +- **Map index** (`m[k]`): present iff the entry exists. The bound name is + the lookup pointer, so field access auto-derefs and field assignments + mutate the underlying map entry in place — no explicit write-back is + needed: + + ```kernelscript + if (var stats = ip_stats[ip]) { + stats.count = stats.count + 1 // writes through the lookup pointer + } else { + ip_stats[ip] = PacketStats { count: 1, total_bytes: 0 } + } + ``` + +- **Pointer-returning expression**: present iff non-null. Useful with + helpers and kfuncs that may return `null`. + +Semantics: + +- **Single evaluation.** `expr` is evaluated exactly once; its presence + test guards both branches. +- **Scoping.** `name` is in scope only inside the then-branch. Referencing + it from the else-branch (or after the `if`) is a compile-time error. +- **No reassignment.** `name` shadows nothing visible to the else-branch + and may shadow an outer binding only inside the then-branch. +- **Else is optional.** As with the expression-condition form, the + `else` branch may be omitted. +- **Lowering.** The form lowers to a single `bpf_map_lookup_elem` (or the + underlying pointer-returning call), a null check, and the chosen + branch — there is no second lookup. + #### 7.5.2 Match Expressions KernelScript provides `match` expressions for efficient multi-way branching. Match is an expression that returns a value and can be used anywhere an expression is expected. @@ -3623,9 +3733,8 @@ pin var global_config : array(64) fn security_filter(ctx: LsmContext) -> i32 { var flow_key = extract_flow_key_from_socket(ctx) - // Check global flow statistics for threat detection - if (global_flows[flow_key] != none) { - var flow_stats = global_flows[flow_key] + // Check global flow statistics for threat detection — single lookup + if (var flow_stats = global_flows[flow_key]) { if (flow_stats.is_suspicious()) { global_events.submit(EVENT_THREAT_DETECTED { flow_key }) return -EPERM // Block connection @@ -3666,8 +3775,7 @@ fn start_coordinator() -> i32 { fn process_events(coordinator: *SystemCoordinator) { // Process events from all programs - var event = coordinator->global_events.read() - if (event != null) { + if (var event = coordinator->global_events.read()) { if (event.event_type == EVENT_PACKET_PROCESSED) { print("Processed packet for flow: ", event.flow_key) } else if (event.event_type == EVENT_THREAT_DETECTED) { @@ -3818,17 +3926,17 @@ var event_log : hash(1024) @helper fn transparent_dynptr_usage(event_data: *u8, data_len: u32) { - // User writes simple pointer code - var log_entry: *u8 = event_log.reserve(data_len + 16) // Dynptr-backed pointer - if (log_entry != null) { + // User writes simple pointer code — IfLet binds the *u8 returned by + // reserve() only inside the truthy branch. + if (var log_entry = event_log.reserve(data_len + 16)) { // Regular pointer operations - compiler uses dynptr API internally var header = log_entry as *EventHeader header->timestamp = bpf_ktime_get_ns() header->data_len = data_len - + // Memory copy using pointer arithmetic memory_copy(event_data, log_entry + 16, data_len) - + event_log.submit(log_entry) // Compiler ensures proper cleanup } } @@ -3915,15 +4023,14 @@ var cache_map : hash(1024) @helper fn map_lifetime_safety(key: u32) { - var cache_entry = cache_map[key] - if (cache_entry != none) { + if (var cache_entry = cache_map[key]) { // Compiler tracks that cache_entry is valid here cache_entry->access_count += 1 cache_entry->last_access = bpf_ktime_get_ns() - + // Compiler warns/errors if cache_entry used after invalidating operations cache_map[other_key] = other_value // Invalidates cache_entry - + // ❌ Compiler error: "Use of potentially invalidated map value pointer" // cache_entry->access_count += 1 } @@ -3971,12 +4078,11 @@ fn kernel_side_processing(ctx: *xdp_md) -> xdp_action { var packet_data = ctx->data() // Shared memory through maps - safe across contexts - var shared_buffer = shared_map[0] - if (shared_buffer != none) { + if (var shared_buffer = shared_map[0]) { shared_buffer->kernel_processed_count += 1 memory_copy(packet_data, shared_buffer->data, min(packet_len, 64)) } - + return XDP_PASS } @@ -3984,14 +4090,13 @@ fn kernel_side_processing(ctx: *xdp_md) -> xdp_action { fn userspace_processing() -> i32 { // ❌ Cannot access kernel context pointers directly // var packet_data = some_kernel_context.data() // Compilation error - + // ✅ Access through shared maps - var shared_buffer = shared_map[0] - if (shared_buffer != none) { + if (var shared_buffer = shared_map[0]) { shared_buffer->userspace_processed_count += 1 process_shared_data(shared_buffer->data) } - + return 0 } ``` @@ -4303,14 +4408,47 @@ statement = expression_statement | assignment_statement | declaration_statement try_statement | throw_statement | defer_statement expression_statement = expression -assignment_statement = identifier assignment_operator expression -assignment_operator = "=" | "+=" | "-=" | "*=" | "/=" | "%=" + +assignment_statement = simple_assignment | compound_assignment | field_assignment | + arrow_assignment | index_assignment | compound_index_assignment | + compound_field_index_assignment + +simple_assignment = identifier "=" expression (* x = e *) +compound_assignment = identifier compound_operator expression (* x op= e *) +field_assignment = primary_expression "." identifier "=" expression (* o.field = e *) +arrow_assignment = primary_expression "->" identifier "=" expression (* p->field = e *) +index_assignment = expression "[" expression "]" "=" expression (* m[k] = e *) +compound_index_assignment = expression "[" expression "]" compound_operator expression + (* m[k] op= e: + scalar map values; reads, applies op, writes back; + absent entries read as 0, so the form creates an + entry on first use. See §6.2.5.1. *) +compound_field_index_assignment = identifier "[" expression "]" "." identifier compound_operator expression + (* m[k].field op= e: + struct-valued map; lowers to a single + bpf_map_lookup_elem + null-checked + ptr->field op= e; absent entries are a no-op + (no entry is created). See §6.2.5.2. *) + +assignment_operator = "=" | compound_operator +compound_operator = "+=" | "-=" | "*=" | "/=" | "%=" declaration_statement = "var" identifier [ ":" type_annotation ] "=" expression -if_statement = "if" "(" expression ")" "{" statement_list "}" - { "else" "if" "(" expression ")" "{" statement_list "}" } - [ "else" "{" statement_list "}" ] +if_statement = expression_if | iflet_if + +expression_if = "if" "(" expression ")" "{" statement_list "}" + { "else" "if" "(" expression ")" "{" statement_list "}" } + [ "else" "{" statement_list "}" ] + +iflet_if = "if" "(" "var" identifier "=" expression ")" "{" statement_list "}" + [ "else" ( "{" statement_list "}" | iflet_if | expression_if ) ] + (* Declaration-as-condition: the right-hand side is evaluated once; + the then-branch is taken iff the value is *present* (a map hit + or a non-null pointer). `identifier` is bound only inside the + then-branch. For map-index right-hand sides the binding is the + lookup pointer (field access auto-derefs, field writes mutate + the underlying entry in place). See §7.5.1.2. *) for_statement = "for" "(" identifier "in" expression ".." expression ")" "{" statement_list "}" | "for" "(" identifier "," identifier ")" "in" expression "{" statement_list "}" diff --git a/examples/map_operations_demo.ks b/examples/map_operations_demo.ks index c6df656..0cdb824 100644 --- a/examples/map_operations_demo.ks +++ b/examples/map_operations_demo.ks @@ -58,7 +58,7 @@ struct ArrayElement { // Safe concurrent read access - multiple programs can read simultaneously var counter = global_counter[key] - if (counter != none) { + if (counter != null) { // High-frequency lookup pattern - will generate optimization suggestions for (i in 0..100) { var _ = global_counter[key + i] @@ -70,16 +70,13 @@ struct ArrayElement { // Per-CPU access for maximum performance var cpu_id = 0 - var data = percpu_data[cpu_id] - if (data != none) { + if (var data = percpu_data[cpu_id]) { data.local_counter = data.local_counter + 1 - percpu_data[cpu_id] = data } else { - var new_data = PerCpuData { + percpu_data[cpu_id] = PerCpuData { local_counter: 1, temp_storage: [0], } - percpu_data[cpu_id] = new_data } return XDP_PASS @@ -92,7 +89,7 @@ fn stats_updater(ctx: *__sk_buff) -> i32 { // Potential write conflict with other programs var stats = shared_stats[ifindex] - if (stats == none) { + if (stats == null) { stats = Statistics { packet_count: 0, byte_count: 0, @@ -116,10 +113,8 @@ fn stats_updater(ctx: *__sk_buff) -> i32 { // Batch operation pattern - will be detected as batch access for (i in 0..20) { var batch_key = ifindex + i - var entry = shared_stats[batch_key] - if (entry != null) { + if (var entry = shared_stats[batch_key]) { entry.packet_count = entry.packet_count + 1 - shared_stats[batch_key] = entry } } @@ -132,8 +127,7 @@ fn event_logger(ctx: *trace_event_raw_sys_enter) -> i32 { // Ring buffer output - single writer recommended try { // Reserve space in the ring buffer - var reserved = event_stream.reserve() - if (reserved != null) { + if (var reserved = event_stream.reserve()) { // Successfully reserved space - populate event data inline reserved->timestamp = 123456 // Fake timestamp reserved->event_type = ctx->id // Use syscall ID from sys_enter context @@ -157,19 +151,16 @@ fn event_logger(ctx: *trace_event_raw_sys_enter) -> i32 { fn data_processor(file: *file, buf: *u8, count: size_t, pos: *i64) -> i32 { // Sequential access pattern - will be detected and optimized for (i in 0..32) { - var element = sequential_data[i] - if (element != none) { + if (var element = sequential_data[i]) { if (!element.processed) { element.value = element.value * 2 element.processed = true - sequential_data[i] = element } } else { - var new_element = ArrayElement { + sequential_data[i] = ArrayElement { value: i, processed: false, } - sequential_data[i] = new_element } } diff --git a/examples/maps_demo.ks b/examples/maps_demo.ks index 1ee263e..5a36db9 100644 --- a/examples/maps_demo.ks +++ b/examples/maps_demo.ks @@ -70,28 +70,21 @@ fn get_timestamp() -> u64 { var cpu_id = get_cpu_id() cpu_counters[cpu_id] = cpu_counters[cpu_id] + 1 - // Update IP statistics - elegant truthy/falsy pattern - var stats = ip_stats[src_ip] - if (stats != none) { - // stats is truthy - entry exists, update it + // Update IP statistics - in-place mutation when entry exists + if (var stats = ip_stats[src_ip]) { stats.count = stats.count + 1 stats.total_bytes = stats.total_bytes + packet_len stats.last_seen = get_timestamp() - ip_stats[src_ip] = stats } else { - // stats is falsy - no entry, create new one - var new_stats = PacketStats { + ip_stats[src_ip] = PacketStats { count: 1, total_bytes: packet_len, last_seen: get_timestamp() } - ip_stats[src_ip] = new_stats } - - // Check recent connections - var recent = recent_connections[src_ip] - if (recent != none) { - // Log repeated connection + + // Log repeated connections + if (recent_connections[src_ip] != null) { event_log[0] = 1 } diff --git a/examples/object_allocation.ks b/examples/object_allocation.ks index 1b25aea..385db47 100644 --- a/examples/object_allocation.ks +++ b/examples/object_allocation.ks @@ -21,7 +21,7 @@ var conn_tracker : hash(1024) // Look up existing connection stats var stats = conn_tracker[src_ip] - if (stats == none) { + if (stats == null) { // First packet from this IP - allocate new stats object stats = new ConnStats() if (stats == null) { diff --git a/examples/rate_limiter.ks b/examples/rate_limiter.ks index 6d6bc8e..6c48f1a 100644 --- a/examples/rate_limiter.ks +++ b/examples/rate_limiter.ks @@ -21,7 +21,7 @@ config network { var src_ip = 0x7F000001 // Placeholder IP (127.0.0.1) // Update the count - if (packet_counts[src_ip] != none) { + if (packet_counts[src_ip] != null) { packet_counts[src_ip] += 1 } else { packet_counts[src_ip] = 0 diff --git a/examples/ringbuf_demo.ks b/examples/ringbuf_demo.ks index 907549d..d485d45 100644 --- a/examples/ringbuf_demo.ks +++ b/examples/ringbuf_demo.ks @@ -44,15 +44,14 @@ fn get_timestamp() -> u64 { @xdp fn network_monitor(ctx: *xdp_md) -> xdp_action { var key: u32 = 0 var stat = stats[key] - if (stat == none) { + if (stat == null) { var init_stat = Stats { events_submitted: 0, events_dropped: 0, buffer_full_count: 0 } stats[key] = init_stat stat = stats[key] } // Try to reserve space in ring buffer - var reserved = network_events.reserve() - if (reserved != null) { + if (var reserved = network_events.reserve()) { // Successfully reserved space - populate event data inline reserved->timestamp = get_timestamp() reserved->event_type = 1 // PACKET_RECEIVED @@ -77,8 +76,7 @@ fn get_timestamp() -> u64 { // Security monitoring program @probe("sys_openat") fn security_monitor(dfd: i32, filename: *u8, flags: i32, mode: u16) -> i32 { - var reserved = security_events.reserve() - if (reserved != null) { + if (var reserved = security_events.reserve()) { // Successfully reserved space - populate security event inline reserved->timestamp = get_timestamp() reserved->severity = 2 // Medium severity diff --git a/examples/safety_demo.ks b/examples/safety_demo.ks index 38673e3..7653130 100644 --- a/examples/safety_demo.ks +++ b/examples/safety_demo.ks @@ -109,8 +109,7 @@ fn array_validation_demo(ctx: *xdp_md) -> xdp_action { // Safe map access var key: u32 = 1 - var count = packet_stats[key] - if (count != none) { + if (var count = packet_stats[key]) { packet_stats[key] = count + 1 } else { packet_stats[key] = 1 diff --git a/examples/test_error_handling.ks b/examples/test_error_handling.ks index c77814f..69fe469 100644 --- a/examples/test_error_handling.ks +++ b/examples/test_error_handling.ks @@ -22,7 +22,7 @@ fn process_key(key: u32) -> u32 { try { // Check if key exists (expected absence - use null) var value = test_map[key] - if (value == none) { + if (value == null) { // Key doesn't exist - create default value (expected pattern) var default_value = 42 test_map[key] = default_value diff --git a/examples/type_checking.ks b/examples/type_checking.ks index e504dae..40ae880 100644 --- a/examples/type_checking.ks +++ b/examples/type_checking.ks @@ -66,9 +66,7 @@ fn classify_protocol(proto: u8) -> ProtocolType { @helper fn update_statistics(header: PacketHeader) { // Type checker validates map operations and key/value types - var current_count = connection_stats[header.src_ip] - - if (current_count != null) { + if (var current_count = connection_stats[header.src_ip]) { // Type checker ensures arithmetic on compatible types connection_stats[header.src_ip] = current_count + 1 } else { diff --git a/examples/types_demo.ks b/examples/types_demo.ks index a0779ab..eb8c47d 100644 --- a/examples/types_demo.ks +++ b/examples/types_demo.ks @@ -63,8 +63,7 @@ fn extract_packet_info(ctx: *xdp_md) -> *PacketInfo { @helper fn get_filter_action(info: PacketInfo) -> FilterAction { // Look up in the filter map - var action = packet_filter[info] - if (action != none) { + if (var action = packet_filter[info]) { return action } else { return FILTER_ACTION_ALLOW @@ -85,17 +84,15 @@ fn protocol_from_u8(proto_num: u8) -> Protocol { @helper fn update_stats(info: PacketInfo) { // Update connection count - var current_count = connection_count[info.src_ip] - if (current_count != none) { + if (var current_count = connection_count[info.src_ip]) { connection_count[info.src_ip] = current_count + 1 } else { connection_count[info.src_ip] = 1 } - + // Update protocol stats var proto = protocol_from_u8(info.protocol) - var stats = protocol_stats[proto] - if (stats != none) { + if (var stats = protocol_stats[proto]) { protocol_stats[proto] = stats + 1 } else { protocol_stats[proto] = 1 @@ -105,9 +102,7 @@ fn update_stats(info: PacketInfo) { // Program using all the new types @xdp fn packet_inspector(ctx: *xdp_md) -> xdp_action { // Extract packet information - var packet_info = extract_packet_info(ctx) - - if (packet_info != none) { + if (var packet_info = extract_packet_info(ctx)) { // Update statistics update_stats(*packet_info) diff --git a/src/ast.ml b/src/ast.ml index 3ff6ae4..7c7873d 100644 --- a/src/ast.ml +++ b/src/ast.ml @@ -81,8 +81,6 @@ type bpf_type = (* Ring buffer reference type - represents a ring buffer for dispatch *) | RingbufRef of bpf_type (* value type *) | Ringbuf of bpf_type * int (* value_type, size - ring buffer object *) - (* None type - represents missing/absent values *) - | NoneType (* Null type - represents null pointers, compatible with any pointer type *) | Null @@ -187,7 +185,6 @@ type literal = | BoolLit of bool | ArrayLit of array_init_style (* Enhanced array initialization *) | NullLit - | NoneLit (** Array initialization styles *) and array_init_style = @@ -285,6 +282,8 @@ and stmt_desc = | Assignment of string * expr | CompoundAssignment of string * binary_op * expr (* var op= expr *) | CompoundIndexAssignment of expr * expr * binary_op * expr (* map[key] op= expr *) + | CompoundFieldIndexAssignment of expr * expr * string * binary_op * expr + (* map[key].field op= expr *) | FieldAssignment of expr * string * expr (* object.field = value *) | ArrowAssignment of expr * string * expr (* pointer->field = value *) | IndexAssignment of expr * expr * expr (* map[key] = value *) @@ -292,6 +291,10 @@ and stmt_desc = | ConstDeclaration of string * bpf_type option * expr (* const name : type = value *) | Return of expr option | If of expr * statement list * statement list option + | IfLet of string * expr * statement list * statement list option + (* if (var name = expr) { then_stmts } else { else_stmts } + Truthy iff expr is "present": map hit, non-null pointer return, etc. + `name` is bound only inside then_stmts. *) | For of string * expr * expr * statement list | ForIter of string * string * expr * statement list (* for (index, value) in expr.iter() { ... } *) | While of expr * statement list @@ -710,7 +713,6 @@ let rec string_of_bpf_type = function | ProgramHandle -> "ProgramHandle" | RingbufRef value_type -> Printf.sprintf "ringbuf_ref<%s>" (string_of_bpf_type value_type) | Ringbuf (value_type, size) -> Printf.sprintf "ringbuf<%s>(%d)" (string_of_bpf_type value_type) size - | NoneType -> "none" | Null -> "null" let rec string_of_literal = function @@ -726,7 +728,6 @@ let rec string_of_literal = function Printf.sprintf "[%s]" (String.concat ", " (List.map string_of_literal literals)) | ArrayLit (ZeroArray) -> "[]" | NullLit -> "null" - | NoneLit -> "none" let string_of_binary_op = function | Add -> "+" @@ -814,6 +815,10 @@ and string_of_stmt stmt = Printf.sprintf "%s %s= %s;" name (string_of_binary_op op) (string_of_expr expr) | CompoundIndexAssignment (map_expr, key_expr, op, value_expr) -> Printf.sprintf "%s[%s] %s= %s;" (string_of_expr map_expr) (string_of_expr key_expr) (string_of_binary_op op) (string_of_expr value_expr) + | CompoundFieldIndexAssignment (map_expr, key_expr, field, op, value_expr) -> + Printf.sprintf "%s[%s].%s %s= %s;" + (string_of_expr map_expr) (string_of_expr key_expr) field + (string_of_binary_op op) (string_of_expr value_expr) | FieldAssignment (obj_expr, field, value_expr) -> Printf.sprintf "%s.%s = %s;" (string_of_expr obj_expr) field (string_of_expr value_expr) | ArrowAssignment (obj_expr, field, value_expr) -> @@ -842,10 +847,18 @@ and string_of_stmt stmt = let then_str = String.concat " " (List.map string_of_stmt then_stmts) in let else_str = match else_opt with | None -> "" - | Some else_stmts -> + | Some else_stmts -> " else { " ^ String.concat " " (List.map string_of_stmt else_stmts) ^ " }" in Printf.sprintf "if (%s) { %s }%s" (string_of_expr cond) then_str else_str + | IfLet (name, expr, then_stmts, else_opt) -> + let then_str = String.concat " " (List.map string_of_stmt then_stmts) in + let else_str = match else_opt with + | None -> "" + | Some else_stmts -> + " else { " ^ String.concat " " (List.map string_of_stmt else_stmts) ^ " }" + in + Printf.sprintf "if (var %s = %s) { %s }%s" name (string_of_expr expr) then_str else_str | For (var, start, end_, body) -> let body_str = String.concat " " (List.map string_of_stmt body) in Printf.sprintf "for (%s in %s..%s) { %s }" diff --git a/src/ebpf_c_codegen.ml b/src/ebpf_c_codegen.ml index 5747f60..5574b16 100644 --- a/src/ebpf_c_codegen.ml +++ b/src/ebpf_c_codegen.ml @@ -1230,8 +1230,7 @@ let rec generate_c_value ?(auto_deref_map_access=false) ctx ir_val = | IRLiteral (BoolLit b) -> if b then "1" else "0" | IRLiteral (CharLit c) -> sprintf "'%c'" c | IRLiteral (NullLit) -> "NULL" - | IRLiteral (NoneLit) -> "0" - | IRLiteral (StringLit s) -> + | IRLiteral (StringLit s) -> (* Generate string literal as struct initialization *) (match ir_val.val_type with | IRStr size -> @@ -1263,7 +1262,6 @@ let rec generate_c_value ?(auto_deref_map_access=false) ctx ir_val = | Ast.CharLit c -> sprintf "'%c'" c | Ast.StringLit s -> sprintf "\"%s\"" (escape_c_string s) | Ast.NullLit -> "NULL" - | Ast.NoneLit -> "0" | Ast.ArrayLit _ -> "{0}" (* Nested arrays simplified *) in "{" ^ fill_str ^ "}" @@ -1275,7 +1273,6 @@ let rec generate_c_value ?(auto_deref_map_access=false) ctx ir_val = | Ast.CharLit c -> sprintf "'%c'" c | Ast.StringLit s -> sprintf "\"%s\"" (escape_c_string s) | Ast.NullLit -> "NULL" - | Ast.NoneLit -> "0" | Ast.ArrayLit _ -> "{0}" (* Nested arrays simplified *) ) elements in if List.length elements = 0 then @@ -1433,29 +1430,43 @@ let generate_c_expression ctx ir_expr = let index_str = generate_c_value ctx right in sprintf "%s.data[%s]" array_str index_str | _ -> - (* Check for none comparisons first *) + (* `null` comparisons against a map-access lower to a presence + check against the underlying lookup pointer (or against the + value directly when it is already a pointer), so + `if (var x = map[k])` and `entry != null` produce correct C + without an extra dereference. *) + let is_absence_lit = function + | IRLiteral (Ast.NullLit) -> true + | _ -> false + in (match left.value_desc, op, right.value_desc with - | _, IREq, IRLiteral (Ast.NoneLit) - | IRLiteral (Ast.NoneLit), IREq, _ -> - (* Comparison with none: check if pointer is NULL *) - let non_none_val = if left.value_desc = IRLiteral (Ast.NoneLit) then right else left in - (* For IRMapAccess, use the underlying pointer directly for NULL check *) - let val_str = (match non_none_val.value_desc with + | _, IREq, _ when is_absence_lit right.value_desc -> + let val_str = (match left.value_desc with + | IRMapAccess (_, _, (underlying_desc, underlying_type)) -> + let underlying_val = { value_desc = underlying_desc; val_type = underlying_type; stack_offset = None; bounds_checked = false; val_pos = left.val_pos } in + generate_c_value ~auto_deref_map_access:false ctx underlying_val + | _ -> generate_c_value ctx left) in + sprintf "(%s == NULL)" val_str + | _, IREq, _ when is_absence_lit left.value_desc -> + let val_str = (match right.value_desc with | IRMapAccess (_, _, (underlying_desc, underlying_type)) -> - let underlying_val = { value_desc = underlying_desc; val_type = underlying_type; stack_offset = None; bounds_checked = false; val_pos = non_none_val.val_pos } in + let underlying_val = { value_desc = underlying_desc; val_type = underlying_type; stack_offset = None; bounds_checked = false; val_pos = right.val_pos } in generate_c_value ~auto_deref_map_access:false ctx underlying_val - | _ -> generate_c_value ctx non_none_val) in + | _ -> generate_c_value ctx right) in sprintf "(%s == NULL)" val_str - | _, IRNe, IRLiteral (Ast.NoneLit) - | IRLiteral (Ast.NoneLit), IRNe, _ -> - (* Not-equal comparison with none: check if pointer is not NULL *) - let non_none_val = if left.value_desc = IRLiteral (Ast.NoneLit) then right else left in - (* For IRMapAccess, use the underlying pointer directly for NULL check *) - let val_str = (match non_none_val.value_desc with + | _, IRNe, _ when is_absence_lit right.value_desc -> + let val_str = (match left.value_desc with | IRMapAccess (_, _, (underlying_desc, underlying_type)) -> - let underlying_val = { value_desc = underlying_desc; val_type = underlying_type; stack_offset = None; bounds_checked = false; val_pos = non_none_val.val_pos } in + let underlying_val = { value_desc = underlying_desc; val_type = underlying_type; stack_offset = None; bounds_checked = false; val_pos = left.val_pos } in generate_c_value ~auto_deref_map_access:false ctx underlying_val - | _ -> generate_c_value ctx non_none_val) in + | _ -> generate_c_value ctx left) in + sprintf "(%s != NULL)" val_str + | _, IRNe, _ when is_absence_lit left.value_desc -> + let val_str = (match right.value_desc with + | IRMapAccess (_, _, (underlying_desc, underlying_type)) -> + let underlying_val = { value_desc = underlying_desc; val_type = underlying_type; stack_offset = None; bounds_checked = false; val_pos = right.val_pos } in + generate_c_value ~auto_deref_map_access:false ctx underlying_val + | _ -> generate_c_value ctx right) in sprintf "(%s != NULL)" val_str | _ -> (* Regular binary operation - auto-dereference map access for operands *) @@ -1936,6 +1947,10 @@ and generate_c_instruction ctx ir_instr = (* Other string expressions (concatenation, etc.) *) let init_str = generate_c_expression ctx init_expr in emit_line ctx (sprintf "%s %s = %s;" type_str var_name init_str) + | IRPointer _, IRValue src_val when (match src_val.value_desc with IRMapAccess _ -> true | _ -> false) -> + (* Pointer-typed variable initialized from a map lookup: keep the pointer. *) + let init_str = generate_c_value ~auto_deref_map_access:false ctx src_val in + emit_line ctx (sprintf "%s %s = %s;" type_str var_name init_str) | _ -> (* Regular non-string assignment *) let init_str = generate_c_expression ctx init_expr in diff --git a/src/evaluator.ml b/src/evaluator.ml index f0268e8..49a5a37 100644 --- a/src/evaluator.ml +++ b/src/evaluator.ml @@ -37,7 +37,6 @@ type runtime_value = | ContextValue of string * (string * runtime_value) list | NullValue (* Simple null value representation *) | UnitValue - | None (* Sentinel value for map lookup failures and missing values *) (** Additional exceptions that depend on runtime_value *) exception Return_value of runtime_value @@ -54,7 +53,6 @@ let rec runtime_values_equal v1 v2 = | EnumValue (name1, val1), EnumValue (name2, val2) -> name1 = name2 && val1 = val2 | NullValue, NullValue -> true | UnitValue, UnitValue -> true - | None, None -> true | PointerValue addr1, PointerValue addr2 -> addr1 = addr2 | MapHandle name1, MapHandle name2 -> name1 = name2 | ArrayValue arr1, ArrayValue arr2 -> @@ -326,7 +324,6 @@ let rec string_of_runtime_value = function name ^ " = " ^ string_of_runtime_value value) fields)) | NullValue -> "null" | UnitValue -> "()" - | None -> "none" (** Convert literal to runtime value *) let runtime_value_of_literal = function @@ -335,7 +332,6 @@ let runtime_value_of_literal = function | CharLit c -> CharValue c | BoolLit b -> BoolValue b | NullLit -> NullValue (* null is represented as simple null value *) - | NoneLit -> None (* none is represented as none sentinel value *) | ArrayLit _literals -> (* TODO: Implement array literal evaluation *) failwith "Array literal evaluation not implemented yet" @@ -359,7 +355,6 @@ let is_truthy_value rv = | ContextValue (_, _) -> true (* context values are always truthy *) | NullValue -> false (* null is always falsy *) | UnitValue -> false (* unit value is falsy *) - | None -> false (* none is always falsy *) | ArrayValue _ -> failwith "Arrays cannot be used in boolean context" | StructValue _ -> failwith "Structs cannot be used in boolean context" @@ -405,15 +400,7 @@ let eval_binary_op left_val op right_val pos = | Ne, NullValue, _ -> BoolValue true | Eq, _, NullValue -> BoolValue false | Ne, _, NullValue -> BoolValue true - - (* None comparisons *) - | Eq, None, None -> BoolValue true - | Ne, None, None -> BoolValue false - | Eq, None, _ -> BoolValue false - | Ne, None, _ -> BoolValue true - | Eq, _, None -> BoolValue false - | Ne, _, None -> BoolValue true - + (* Logical operations *) | And, BoolValue l, BoolValue r -> BoolValue (l && r) | Or, BoolValue l, BoolValue r -> BoolValue (l || r) @@ -575,8 +562,8 @@ and eval_array_access ctx arr_expr idx_expr pos = (try Hashtbl.find map_store key_str with Not_found -> - (* For map access, return sentinel value for missing keys *) - None) + (* For map access, missing keys evaluate to null *) + NullValue) | _ -> (* Regular array access *) let arr_val = eval_expression ctx arr_expr in @@ -804,6 +791,11 @@ and eval_statement ctx stmt = in let result = eval_binary_op current_val op value_val stmt.stmt_pos in Hashtbl.replace map_store key_str result + + | CompoundFieldIndexAssignment (_, _, _, _, _) -> + (* The interpreter is used for compile-time evaluation only; + struct-field compound assignment on map values is a runtime construct. *) + eval_error "map[key].field op= rhs is not supported in the interpreter" stmt.stmt_pos | FieldAssignment (obj_expr, _field, value_expr) -> (* For evaluation purposes, treat config field assignment as no-op *) @@ -878,6 +870,21 @@ and eval_statement ctx stmt = (match else_opt with | Some else_stmts -> eval_statements ctx else_stmts | None -> ()) + + | IfLet (name, expr, then_stmts, else_opt) -> + let v = eval_expression ctx expr in + let present = is_truthy_value v in + if present then begin + let old = try Some (Hashtbl.find ctx.variables name) with Not_found -> None in + Hashtbl.replace ctx.variables name v; + eval_statements ctx then_stmts; + (match old with + | Some o -> Hashtbl.replace ctx.variables name o + | None -> Hashtbl.remove ctx.variables name) + end else + (match else_opt with + | Some else_stmts -> eval_statements ctx else_stmts + | None -> ()) | For (var, start_expr, end_expr, body) -> let start_val = eval_expression ctx start_expr in diff --git a/src/ir.ml b/src/ir.ml index d3f041f..773ace8 100644 --- a/src/ir.ml +++ b/src/ir.ml @@ -729,7 +729,6 @@ let rec ast_type_to_ir_type = function | ProgramHandle -> IRI32 (* Program handles are represented as file descriptors (i32) in IR to support error codes *) | Ringbuf (value_type, size) -> IRRingbuf (ast_type_to_ir_type value_type, size) (* Ring buffer object *) | RingbufRef _ -> IRU32 (* Ring buffer references are represented as pointers/handles (u32) in IR *) - | NoneType -> IRU32 (* None type represented as u32 sentinel value in IR *) | Null -> IRPointer (IRU32, {min_size = Some 0; max_size = Some 0; alignment = 1; nullable = true}) (* Null is represented as a nullable pointer in IR *) (* Helper function that preserves type aliases when converting AST types to IR types *) diff --git a/src/ir_generator.ml b/src/ir_generator.ml index 5e71e4b..746c008 100644 --- a/src/ir_generator.ml +++ b/src/ir_generator.ml @@ -67,6 +67,12 @@ type ir_context = { map_origin_variables: (string, (string * ir_value * (ir_value_desc * ir_type))) Hashtbl.t; (* var_name -> (map_name, key, underlying_info) *) (* Track inferred variable types for proper lookups *) variable_types: (string, ir_type) Hashtbl.t; (* var_name -> ir_type *) + (* Active IfLet bindings: source name -> synthetic IR name, for the duration + of the then-branch. Reads, simple assignments, and compound assignments + of the source name are rewritten to the synthetic name; the synthetic + name is what was actually declared in IR, so an outer variable of the + same name is never clobbered when the backend hoists declarations. *) + iflet_aliases: (string, string) Hashtbl.t; mutable current_program_type: program_type option; } @@ -92,6 +98,7 @@ let create_context ?(global_variables = []) ?(helper_functions = []) symbol_tabl tbl); map_origin_variables = Hashtbl.create 32; variable_types = Hashtbl.create 32; + iflet_aliases = Hashtbl.create 4; current_program_type = None; helper_functions = (let tbl = Hashtbl.create 16 in List.iter (fun helper_name -> Hashtbl.add tbl helper_name ()) helper_functions; @@ -234,11 +241,10 @@ let lower_literal lit pos = | StringLit s -> IRStr (max 1 (String.length s)) (* String literals get IRStr type *) | CharLit _ -> IRChar | BoolLit _ -> IRBool - | NullLit -> + | NullLit -> let bounds = make_bounds_info ~nullable:true () in IRPointer (IRU32, bounds) (* null literal as nullable pointer to u32 *) - | NoneLit -> IRU32 (* none literal as sentinel u32 value *) - | ArrayLit init_style -> + | ArrayLit init_style -> (* Handle enhanced array literal lowering *) (match init_style with | ZeroArray -> @@ -251,10 +257,9 @@ let lower_literal lit pos = | BoolLit _ -> IRBool | CharLit _ -> IRChar | StringLit _ -> IRPointer (IRU8, make_bounds_info ~nullable:false ()) - | NullLit -> + | NullLit -> let bounds = make_bounds_info ~nullable:true () in IRPointer (IRU32, bounds) - | NoneLit -> IRU32 (* none literal as sentinel u32 value *) | ArrayLit _ -> IRU32 (* Nested arrays default to u32 *) in IRArray (element_ir_type, 0, make_bounds_info ()) (* Size resolved during type unification *) @@ -271,10 +276,9 @@ let lower_literal lit pos = | CharLit _ -> IRChar | StringLit _ -> IRPointer (IRU8, make_bounds_info ~nullable:false ()) | ArrayLit _ -> IRU32 (* Nested arrays default to u32 *) - | NullLit -> + | NullLit -> let bounds = make_bounds_info ~nullable:true () in IRPointer (IRU32, bounds) - | NoneLit -> IRU32 (* none literal as sentinel u32 value *) in let bounds_info = make_bounds_info ~min_size:element_count ~max_size:element_count () in IRArray (element_ir_type, element_count, bounds_info)) @@ -314,7 +318,6 @@ let literal_to_ir_type = function | CharLit _ -> IRChar | StringLit _ -> IRPointer (IRU8, make_bounds_info ~nullable:false ()) | NullLit -> IRPointer (IRU32, make_bounds_info ~nullable:true ()) - | NoneLit -> IRU32 | ArrayLit _ -> IRU32 (* Default for arrays *) (** Unified AST to IR type conversion for basic types *) @@ -561,6 +564,14 @@ let expand_map_operation ctx map_name operation key_val value_val_opt pos = | _ -> failwith ("Unknown map operation: " ^ operation) +(** Resolve a source-level identifier to its current IR-level name. + Returns the synthetic name if the identifier is currently bound by an + enclosing IfLet, otherwise the name unchanged. *) +let resolve_iflet_alias ctx name = + match Hashtbl.find_opt ctx.iflet_aliases name with + | Some synth -> synth + | None -> name + (** Lower AST expressions to IR values *) let rec lower_expression ctx (expr : Ast.expr) = match expr.expr_desc with @@ -568,6 +579,7 @@ let rec lower_expression ctx (expr : Ast.expr) = lower_literal lit expr.expr_pos | Ast.Identifier name -> + let name = resolve_iflet_alias ctx name in (* Check if this is a map identifier *) if Hashtbl.mem ctx.maps name then (* For map identifiers, create a map reference *) @@ -1454,6 +1466,7 @@ and lower_statement ctx stmt = ()) | Ast.Assignment (name, expr) -> + let name = resolve_iflet_alias ctx name in let value = lower_expression ctx expr in (* Track if this assignment is from a map access *) @@ -1505,6 +1518,7 @@ and lower_statement ctx stmt = emit_instruction ctx assign_instr | Ast.CompoundAssignment (name, op, expr) -> + let name = resolve_iflet_alias ctx name in let value = lower_expression ctx expr in (* Check if this is a global variable assignment *) @@ -2039,11 +2053,108 @@ and lower_statement ctx stmt = in (* Generate IRIf instruction *) - let if_instr = make_ir_instruction + let if_instr = make_ir_instruction (IRIf (cond_val, !then_instructions, else_instrs_opt)) stmt.stmt_pos in emit_instruction ctx if_instr - + + | Ast.CompoundFieldIndexAssignment (map_expr, key_expr, field, op, value_expr) -> + (* Desugar `map[k].field op= rhs` to: + var __cidx_field_N = map[k] + if (__cidx_field_N != null) { + __cidx_field_N.field = __cidx_field_N.field op rhs + } + The synthetic name is fresh, so it cannot collide with any user + variable — we go straight to a plain Declaration + If rather than + routing through `Ast.IfLet` (whose alpha-rename machinery is only + needed when the binding name comes from user source). The + field-store lowers to a pointer-checked `ptr->field = ...` via + IRStructFieldAssignment. We look up the field's AST type from the + map's value-struct definition so the synthesized FieldAccess / + BinaryOp get the correct expr_type — without this the IR generator + defaults to IRU32, mis-sizing wider fields. *) + let pos = stmt.stmt_pos in + let synth_name = generate_temp_variable ctx "cidx_field" in + let map_name = match map_expr.expr_desc with + | Ast.Identifier mn -> mn + | _ -> failwith "Compound field-index assignment requires a map identifier" + in + let map_def = Hashtbl.find ctx.maps map_name in + let field_ast_type = + let rec resolve_struct_name = function + | Ast.Struct n | Ast.UserType n -> Some n + | _ -> None + and resolve t = + match resolve_struct_name t with + | Some n -> + (match Symbol_table.lookup_symbol ctx.symbol_table n with + | Some { kind = Symbol_table.TypeDef (Ast.StructDef (_, fs, _)); _ } -> + (try Some (List.assoc field fs) with Not_found -> None) + | _ -> None) + | None -> None + in + resolve map_def.ast_value_type + in + let mk_expr ?ty d = + { Ast.expr_desc = d; expr_pos = pos; expr_type = ty; + type_checked = false; program_context = None; map_scope = None } in + let access = mk_expr (Ast.ArrayAccess (map_expr, key_expr)) in + let tmp_id = mk_expr (Ast.Identifier synth_name) in + let cur_field = mk_expr ?ty:field_ast_type (Ast.FieldAccess (tmp_id, field)) in + let bin = mk_expr ?ty:field_ast_type (Ast.BinaryOp (cur_field, op, value_expr)) in + let store = + { Ast.stmt_desc = Ast.FieldAssignment (tmp_id, field, bin); stmt_pos = pos } in + let cond = + mk_expr ~ty:Ast.Bool + (Ast.BinaryOp (tmp_id, Ast.Ne, mk_expr (Ast.Literal Ast.NullLit))) in + lower_statement ctx + { Ast.stmt_desc = Ast.Declaration (synth_name, None, Some access); + stmt_pos = pos }; + lower_statement ctx + { Ast.stmt_desc = Ast.If (cond, [store], None); stmt_pos = pos } + + | Ast.IfLet (name, expr, then_stmts, else_opt) -> + (* Desugar `if (var name = expr) { T } else { E }` into: + var __iflet__ = expr + if (__iflet__ != null) { T } else { E } + The synthetic name is what the IR actually declares, so an outer + variable of the same name is never clobbered when the backend + hoists declarations to function scope. References to `name` inside + `T` are redirected to the synthetic name through + `ctx.iflet_aliases`, which is set up only around the lowering of + the then-branch — the else-branch sees the un-aliased name. The + codegen rule for `IRMapAccess NullLit` (and the symmetric + form for raw pointers) emits a pointer presence check, so this + lowers correctly without an extra dereference. *) + let pos = stmt.stmt_pos in + let synth = generate_temp_variable ctx ("iflet_" ^ name) in + let mk_expr ?ty d = + { Ast.expr_desc = d; expr_pos = pos; expr_type = ty; + type_checked = false; program_context = None; map_scope = None } in + lower_statement ctx + { Ast.stmt_desc = Ast.Declaration (synth, None, Some expr); stmt_pos = pos }; + let cond_val = lower_expression ctx (mk_expr ~ty:Ast.Bool + (Ast.BinaryOp + (mk_expr ?ty:expr.Ast.expr_type (Ast.Identifier synth), + Ast.Ne, + mk_expr (Ast.Literal Ast.NullLit)))) in + let collect_block stmts = + let saved = ctx.current_block in + ctx.current_block <- []; + List.iter (lower_statement ctx) stmts; + let instrs = List.rev ctx.current_block in + ctx.current_block <- saved; + instrs in + let prev_alias = Hashtbl.find_opt ctx.iflet_aliases name in + Hashtbl.replace ctx.iflet_aliases name synth; + let then_instrs = collect_block then_stmts in + (match prev_alias with + | Some p -> Hashtbl.replace ctx.iflet_aliases name p + | None -> Hashtbl.remove ctx.iflet_aliases name); + let else_instrs_opt = Option.map collect_block else_opt in + emit_instruction ctx + (make_ir_instruction (IRIf (cond_val, then_instrs, else_instrs_opt)) pos) + | Ast.For (var, start_expr, end_expr, body_stmts) -> (* Analyze the loop to determine if it's bounded or unbounded *) let loop_analysis = diff --git a/src/lexer.mll b/src/lexer.mll index cd55833..45b1ca6 100644 --- a/src/lexer.mll +++ b/src/lexer.mll @@ -122,7 +122,6 @@ | "true" -> BOOL_LIT true | "false" -> BOOL_LIT false | "null" -> NULL - | "none" -> NONE | id -> IDENTIFIER id } diff --git a/src/multi_program_analyzer.ml b/src/multi_program_analyzer.ml index fbe94be..dd7b20a 100644 --- a/src/multi_program_analyzer.ml +++ b/src/multi_program_analyzer.ml @@ -187,6 +187,10 @@ let analyze_map_usage (programs: program_def list) (global_maps: map_declaration analyze_expr_for_maps prog_name map_expr; analyze_expr_for_maps prog_name key_expr; analyze_expr_for_maps prog_name value_expr + | CompoundFieldIndexAssignment (map_expr, key_expr, _, _, value_expr) -> + analyze_expr_for_maps prog_name map_expr; + analyze_expr_for_maps prog_name key_expr; + analyze_expr_for_maps prog_name value_expr | FieldAssignment (obj_expr, _, value_expr) -> analyze_expr_for_maps prog_name obj_expr; analyze_expr_for_maps prog_name value_expr @@ -211,6 +215,12 @@ let analyze_map_usage (programs: program_def list) (global_maps: map_declaration (match else_stmts_opt with | Some else_stmts -> List.iter (analyze_stmt_for_maps prog_name) else_stmts | None -> ()) + | IfLet (_, expr, then_stmts, else_stmts_opt) -> + analyze_expr_for_maps prog_name expr; + List.iter (analyze_stmt_for_maps prog_name) then_stmts; + (match else_stmts_opt with + | Some else_stmts -> List.iter (analyze_stmt_for_maps prog_name) else_stmts + | None -> ()) | For (_, start_expr, end_expr, body_stmts) -> analyze_expr_for_maps prog_name start_expr; analyze_expr_for_maps prog_name end_expr; diff --git a/src/parse.ml b/src/parse.ml index f8d16ea..09802cb 100644 --- a/src/parse.ml +++ b/src/parse.ml @@ -96,6 +96,8 @@ let validate_ast ast = | CompoundAssignment (_, _, expr) -> validate_expr expr | CompoundIndexAssignment (map_expr, key_expr, _, value_expr) -> validate_expr map_expr && validate_expr key_expr && validate_expr value_expr + | CompoundFieldIndexAssignment (map_expr, key_expr, _, _, value_expr) -> + validate_expr map_expr && validate_expr key_expr && validate_expr value_expr | FieldAssignment (obj_expr, _, value_expr) -> validate_expr obj_expr && validate_expr value_expr | ArrowAssignment (obj_expr, _, value_expr) -> @@ -110,7 +112,11 @@ let validate_ast ast = | Return None -> true | Return (Some expr) -> validate_expr expr | If (cond, then_stmts, else_opt) -> - validate_expr cond && + validate_expr cond && + List.for_all validate_stmt then_stmts && + (match else_opt with None -> true | Some stmts -> List.for_all validate_stmt stmts) + | IfLet (_, expr, then_stmts, else_opt) -> + validate_expr expr && List.for_all validate_stmt then_stmts && (match else_opt with None -> true | Some stmts -> List.for_all validate_stmt stmts) | For (_, start, end_, body) -> diff --git a/src/parser.mly b/src/parser.mly index 1f90250..167c83d 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -46,7 +46,7 @@ %token STRING IDENTIFIER %token CHAR_LIT %token BOOL_LIT -%token NULL NONE +%token NULL /* Keywords */ %token FN EXTERN INCLUDE PIN TYPE STRUCT ENUM IMPL @@ -126,6 +126,7 @@ %type assignment_or_expression_statement %type compound_assignment_statement %type compound_index_assignment_statement +%type compound_field_index_assignment_statement %type field_assignment_statement %type arrow_assignment_statement %type index_assignment_statement @@ -313,6 +314,7 @@ statement: | index_assignment_statement { $1 } | compound_assignment_statement { $1 } | compound_index_assignment_statement { $1 } + | compound_field_index_assignment_statement { $1 } | assignment_or_expression_statement { $1 } | return_statement { $1 } | if_statement { $1 } @@ -380,6 +382,18 @@ compound_index_assignment_statement: | expression LBRACKET expression RBRACKET MODULO_ASSIGN expression { make_stmt (CompoundIndexAssignment ($1, $3, Mod, $6)) (make_pos ()) } +compound_field_index_assignment_statement: + | expression LBRACKET expression RBRACKET DOT IDENTIFIER PLUS_ASSIGN expression + { make_stmt (CompoundFieldIndexAssignment ($1, $3, $6, Add, $8)) (make_pos ()) } + | expression LBRACKET expression RBRACKET DOT IDENTIFIER MINUS_ASSIGN expression + { make_stmt (CompoundFieldIndexAssignment ($1, $3, $6, Sub, $8)) (make_pos ()) } + | expression LBRACKET expression RBRACKET DOT IDENTIFIER MULTIPLY_ASSIGN expression + { make_stmt (CompoundFieldIndexAssignment ($1, $3, $6, Mul, $8)) (make_pos ()) } + | expression LBRACKET expression RBRACKET DOT IDENTIFIER DIVIDE_ASSIGN expression + { make_stmt (CompoundFieldIndexAssignment ($1, $3, $6, Div, $8)) (make_pos ()) } + | expression LBRACKET expression RBRACKET DOT IDENTIFIER MODULO_ASSIGN expression + { make_stmt (CompoundFieldIndexAssignment ($1, $3, $6, Mod, $8)) (make_pos ()) } + return_statement: | RETURN { make_stmt (Return None) (make_pos ()) } | RETURN expression { make_stmt (Return (Some $2)) (make_pos ()) } @@ -391,6 +405,12 @@ if_statement: { make_stmt (If ($3, $6, Some $10)) (make_pos ()) } | IF LPAREN expression RPAREN LBRACE statement_list RBRACE ELSE if_statement { make_stmt (If ($3, $6, Some [$9])) (make_pos ()) } + | IF LPAREN VAR IDENTIFIER ASSIGN expression RPAREN LBRACE statement_list RBRACE + { make_stmt (IfLet ($4, $6, $9, None)) (make_pos ()) } + | IF LPAREN VAR IDENTIFIER ASSIGN expression RPAREN LBRACE statement_list RBRACE ELSE LBRACE statement_list RBRACE + { make_stmt (IfLet ($4, $6, $9, Some $13)) (make_pos ()) } + | IF LPAREN VAR IDENTIFIER ASSIGN expression RPAREN LBRACE statement_list RBRACE ELSE if_statement + { make_stmt (IfLet ($4, $6, $9, Some [$12])) (make_pos ()) } while_statement: | WHILE LPAREN expression RPAREN LBRACE statement_list RBRACE @@ -496,7 +516,6 @@ literal: | CHAR_LIT { CharLit $1 } | BOOL_LIT { BoolLit $1 } | NULL { NullLit } - | NONE { NoneLit } | LBRACKET array_init_expr RBRACKET { ArrayLit $2 } array_init_expr: diff --git a/src/safety_checker.ml b/src/safety_checker.ml index a661706..39cb2e8 100644 --- a/src/safety_checker.ml +++ b/src/safety_checker.ml @@ -242,6 +242,10 @@ let analyze_statement_bounds stmt = errors := check_array_bounds map_expr @ !errors; errors := check_array_bounds key_expr @ !errors; errors := check_array_bounds value_expr @ !errors + | CompoundFieldIndexAssignment (map_expr, key_expr, _, _, value_expr) -> + errors := check_array_bounds map_expr @ !errors; + errors := check_array_bounds key_expr @ !errors; + errors := check_array_bounds value_expr @ !errors | FieldAssignment (obj_expr, _, value_expr) -> errors := check_array_bounds obj_expr @ !errors; errors := check_array_bounds value_expr @ !errors diff --git a/src/symbol_table.ml b/src/symbol_table.ml index dd94622..de0503a 100644 --- a/src/symbol_table.ml +++ b/src/symbol_table.ml @@ -688,6 +688,10 @@ and process_statement table stmt = process_expression table map_expr; process_expression table key_expr; process_expression table value_expr + | CompoundFieldIndexAssignment (map_expr, key_expr, _field, _, value_expr) -> + process_expression table map_expr; + process_expression table key_expr; + process_expression table value_expr | FieldAssignment (obj_expr, _field, value_expr) -> process_expression table obj_expr; process_expression table value_expr @@ -720,6 +724,21 @@ and process_statement table stmt = List.iter (process_statement table_with_else) else_stmts; let _ = exit_scope table_with_else in () | None -> ()) + + | IfLet (name, expr, then_stmts, else_opt) -> + process_expression table expr; + let table_with_block = enter_scope table BlockScope in + (* Bind `name` only inside the truthy branch. + Type is unknown at this stage; type checker fills the precise type. *) + add_variable table_with_block name U32 stmt.stmt_pos; + List.iter (process_statement table_with_block) then_stmts; + let _ = exit_scope table_with_block in + (match else_opt with + | Some else_stmts -> + let table_with_else = enter_scope table BlockScope in + List.iter (process_statement table_with_else) else_stmts; + let _ = exit_scope table_with_else in () + | None -> ()) | For (var_name, start_expr, end_expr, body) -> process_expression table start_expr; diff --git a/src/type_checker.ml b/src/type_checker.ml index 8a95a99..b43b6c6 100644 --- a/src/type_checker.ml +++ b/src/type_checker.ml @@ -94,6 +94,8 @@ and typed_stmt_desc = | TAssignment of string * typed_expr | TCompoundAssignment of string * binary_op * typed_expr (* var op= expr *) | TCompoundIndexAssignment of typed_expr * typed_expr * binary_op * typed_expr (* map[key] op= expr *) + | TCompoundFieldIndexAssignment of typed_expr * typed_expr * string * binary_op * typed_expr + (* map[key].field op= expr *) | TFieldAssignment of typed_expr * string * typed_expr (* object, field, value *) | TArrowAssignment of typed_expr * string * typed_expr (* pointer, field, value *) | TIndexAssignment of typed_expr * typed_expr * typed_expr @@ -101,6 +103,8 @@ and typed_stmt_desc = | TConstDeclaration of string * bpf_type * typed_expr | TReturn of typed_expr option | TIf of typed_expr * typed_statement list * typed_statement list option + | TIfLet of string * bpf_type * typed_expr * typed_statement list * typed_statement list option + (* name, bound_type (type of `name` inside then-branch), source expr, then, else *) | TFor of string * typed_expr * typed_expr * typed_statement list | TForIter of string * string * typed_expr * typed_statement list | TWhile of typed_expr * typed_statement list @@ -404,7 +408,6 @@ let get_literal_type lit = | CharLit _ -> Char | BoolLit _ -> Bool | NullLit -> Pointer U32 - | NoneLit -> NoneType | ArrayLit _ -> U32 (* Nested arrays default to u32 *) (** Helper function to check type equality for array literals *) @@ -431,7 +434,6 @@ let type_check_literal lit pos = | CharLit _ -> Char | BoolLit _ -> Bool | NullLit -> Null (* null literal - can unify with any pointer or function type *) - | NoneLit -> NoneType (* none literal represents missing/absent values *) | ArrayLit init_style -> (* Handle enhanced array literal type checking *) (match init_style with @@ -484,7 +486,6 @@ let type_of_literal lit = | CharLit _ -> Char | BoolLit _ -> Bool | NullLit -> Pointer U32 - | NoneLit -> NoneType | ArrayLit init_style -> (* Handle enhanced array literal type checking *) (match init_style with @@ -697,6 +698,16 @@ let rec extract_block_return_type stmts arm_pos = | TIf (_, _, None) -> (* If without else - this doesn't work as a return value *) type_error "If statement without else cannot be used as return value in match arm" arm_pos + | TIfLet (_, _, _, then_stmts, Some else_stmts) -> + let then_type = extract_block_return_type then_stmts arm_pos in + let else_type = extract_block_return_type else_stmts arm_pos in + (match unify_types then_type else_type with + | Some unified_type -> unified_type + | None -> type_error ("If-let branches have incompatible types: " ^ + string_of_bpf_type then_type ^ " vs " ^ + string_of_bpf_type else_type) arm_pos) + | TIfLet (_, _, _, _, None) -> + type_error "If-let without else cannot be used as return value in match arm" arm_pos | _ -> type_error "Block arms must end with a return statement, expression, or if-else statement" arm_pos in @@ -1082,29 +1093,6 @@ and type_check_binary_op ctx left op right pos = (* Null comparisons - any type can be compared with null *) | Null, _ | _, Null -> Bool (* Direct null comparisons *) | _, Pointer _ | Pointer _, _ -> Bool (* Pointer comparisons (legacy) *) - (* None comparisons - allow with map access expressions or variables that could contain map results *) - | NoneType, _ | _, NoneType -> - (* Check if at least one operand is a map access or could reasonably be a map result *) - let has_map_related_value = - (match left.expr_desc with - | ArrayAccess (map_expr, _) -> - (match map_expr.expr_desc with - | Identifier map_name -> Hashtbl.mem ctx.maps map_name - | _ -> false) - | Identifier _ -> true (* Variables can contain map lookup results *) - | _ -> false) || - (match right.expr_desc with - | ArrayAccess (map_expr, _) -> - (match map_expr.expr_desc with - | Identifier map_name -> Hashtbl.mem ctx.maps map_name - | _ -> false) - | Identifier _ -> true (* Variables can contain map lookup results *) - | _ -> false) - in - if has_map_related_value then - Bool - else - type_error "'none' can only be compared with map access expressions or variables that may contain map results" pos | _ -> (match unify_types resolved_left_type resolved_right_type with | Some _ -> Bool @@ -1557,10 +1545,6 @@ and type_check_statement ctx stmt = | Assignment (name, expr) -> let typed_expr = type_check_expression ctx expr in - (* Check if trying to assign none to a variable *) - (match typed_expr.texpr_type with - | NoneType -> type_error ("'none' cannot be assigned to variables. It can only be used in comparisons with map lookup results.") stmt.stmt_pos - | _ -> ()); (* Check if the variable is const by looking it up in the symbol table *) (match Symbol_table.lookup_symbol ctx.symbol_table name with | Some symbol when Symbol_table.is_const_variable symbol -> @@ -1808,7 +1792,57 @@ and type_check_statement ctx stmt = type_error ("Operator " ^ string_of_binary_op op ^ " not supported for type " ^ string_of_bpf_type element_type) stmt.stmt_pos) | _ -> type_error ("Compound index assignment can only be used on maps or arrays") stmt.stmt_pos)) - + + | CompoundFieldIndexAssignment (map_expr, key_expr, field, op, value_expr) -> + let typed_key = type_check_expression ctx key_expr in + let typed_value = type_check_expression ctx value_expr in + let map_name = match map_expr.expr_desc with + | Identifier name when Hashtbl.mem ctx.maps name -> name + | _ -> type_error "Compound field-index assignment requires a map identifier" stmt.stmt_pos + in + let map_decl = Hashtbl.find ctx.maps map_name in + (* Key type *) + let resolved_key_type = resolve_user_type ctx map_decl.ast_key_type in + let resolved_typed_key_type = resolve_user_type ctx typed_key.texpr_type in + (match unify_types resolved_key_type resolved_typed_key_type with + | Some _ -> () + | None -> type_error "Map key type mismatch" stmt.stmt_pos); + (* Resolve the map's value type to a struct *) + let resolved_value_type = resolve_user_type ctx map_decl.ast_value_type in + let struct_name = match resolved_value_type with + | Struct n | UserType n -> n + | _ -> type_error "map[k].field op= rhs requires the map's value type to be a struct" stmt.stmt_pos + in + let fields = + try + (match Hashtbl.find ctx.types struct_name with + | StructDef (_, fs, _) -> fs + | _ -> type_error (struct_name ^ " is not a struct") stmt.stmt_pos) + with Not_found -> type_error ("Undefined struct: " ^ struct_name) stmt.stmt_pos + in + let field_type = + try List.assoc field fields + with Not_found -> + type_error ("Field not found: " ^ field ^ " in struct " ^ struct_name) stmt.stmt_pos + in + (* rhs must match field type *) + let resolved_field_type = resolve_user_type ctx field_type in + let resolved_typed_value_type = resolve_user_type ctx typed_value.texpr_type in + (match unify_types resolved_field_type resolved_typed_value_type with + | Some _ -> () + | None -> type_error ("Field value type mismatch for " ^ field) stmt.stmt_pos); + (* op must be valid for the field type *) + (match op, resolved_field_type with + | (Add | Sub | Mul | Div | Mod), (U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64) -> + let typed_map = { texpr_desc = TIdentifier map_name; + texpr_type = Map (map_decl.ast_key_type, map_decl.ast_value_type, map_decl.ast_map_type, map_decl.max_entries); + texpr_pos = map_expr.expr_pos } in + { tstmt_desc = TCompoundFieldIndexAssignment (typed_map, typed_key, field, op, typed_value); + tstmt_pos = stmt.stmt_pos } + | _, _ -> + type_error ("Operator " ^ string_of_binary_op op ^ + " not supported for field type " ^ string_of_bpf_type resolved_field_type) stmt.stmt_pos) + | Declaration (name, type_opt, expr_opt) -> let typed_expr_opt = Option.map (type_check_expression ctx) expr_opt in @@ -1818,12 +1852,6 @@ and type_check_statement ctx stmt = type_error ("Maps cannot be assigned to variables") stmt.stmt_pos | _ -> ()); - (* Check if trying to assign none to a variable *) - (match typed_expr_opt with - | Some typed_expr when (match typed_expr.texpr_type with NoneType -> true | _ -> false) -> - type_error ("'none' cannot be assigned to variables. It can only be used in comparisons with map lookup results.") stmt.stmt_pos - | _ -> ()); - let var_type = match type_opt with | Some declared_type -> let resolved_declared_type = resolve_user_type ctx declared_type in @@ -1857,11 +1885,6 @@ and type_check_statement ctx stmt = | Map (_, _, _, _) -> type_error ("Maps cannot be assigned to const variables") stmt.stmt_pos | _ -> ()); - (* Check if trying to assign none to a const *) - (match typed_expr.texpr_type with - | NoneType -> type_error ("'none' cannot be assigned to variables. It can only be used in comparisons with map lookup results.") stmt.stmt_pos - | _ -> ()); - (* Validate that the expression is a compile-time constant (literals and negated literals) *) let const_value = match typed_expr.texpr_desc with | TLiteral lit -> lit @@ -2028,6 +2051,49 @@ and type_check_statement ctx stmt = let typed_then = List.map (type_check_statement ctx) then_stmts in let typed_else = Option.map (List.map (type_check_statement ctx)) else_opt in { tstmt_desc = TIf (typed_cond, typed_then, typed_else); tstmt_pos = stmt.stmt_pos } + + | IfLet (name, expr, then_stmts, else_opt) -> + (* `if (var name = expr) { ... }` — bind `name` only inside then-branch. + The bound type matches what `var name = expr` would normally + produce: the value type for map access (auto-deref via + IRMapAccess), and the pointer type for raw pointer expressions. + We restrict the RHS to "presence-producing" expressions, since + the construct's truthiness is defined as "expr produced a present + value" — i.e., a map hit or a non-null pointer. Allowing arbitrary + scalar / struct RHS would let the codegen emit `x != NULL` + against a non-pointer value (clang -Wpointer-integer-compare, + invalid C for struct types) and would let the evaluator's general + truthy-falsy rule diverge from the codegen's pointer presence + check. The legal shapes are: + - `m[k]` where `m` is a known map (auto-deref'd value type at + this layer, but underlying-pointer-checked at codegen) + - any expression of pointer type. *) + let typed_expr = type_check_expression ctx expr in + let bound_type = typed_expr.texpr_type in + let is_map_access_rhs = match expr.expr_desc with + | ArrayAccess ({ expr_desc = Identifier mn; _ }, _) -> + Hashtbl.mem ctx.maps mn + | _ -> false + in + let is_pointer_rhs = match bound_type with + | Pointer _ -> true + | _ -> false + in + if not (is_map_access_rhs || is_pointer_rhs) then + type_error + ("`if (var " ^ name ^ " = expr)` requires expr to be a map access " ^ + "(`m[k]`) or a pointer-typed expression; got " ^ + string_of_bpf_type bound_type) + stmt.stmt_pos; + let saved = Hashtbl.find_opt ctx.variables name in + Hashtbl.replace ctx.variables name bound_type; + let typed_then = List.map (type_check_statement ctx) then_stmts in + (match saved with + | Some t -> Hashtbl.replace ctx.variables name t + | None -> Hashtbl.remove ctx.variables name); + let typed_else = Option.map (List.map (type_check_statement ctx)) else_opt in + { tstmt_desc = TIfLet (name, bound_type, typed_expr, typed_then, typed_else); + tstmt_pos = stmt.stmt_pos } | For (var, start, end_, body) -> if !loop_depth > 0 then @@ -2570,6 +2636,8 @@ and typed_stmt_to_stmt tstmt = | TCompoundAssignment (name, op, expr) -> CompoundAssignment (name, op, typed_expr_to_expr expr) | TCompoundIndexAssignment (map_expr, key_expr, op, value_expr) -> CompoundIndexAssignment (typed_expr_to_expr map_expr, typed_expr_to_expr key_expr, op, typed_expr_to_expr value_expr) + | TCompoundFieldIndexAssignment (map_expr, key_expr, field, op, value_expr) -> + CompoundFieldIndexAssignment (typed_expr_to_expr map_expr, typed_expr_to_expr key_expr, field, op, typed_expr_to_expr value_expr) | TFieldAssignment (obj_expr, field, value_expr) -> FieldAssignment (typed_expr_to_expr obj_expr, field, typed_expr_to_expr value_expr) | TArrowAssignment (obj_expr, field, value_expr) -> @@ -2579,10 +2647,15 @@ and typed_stmt_to_stmt tstmt = | TDeclaration (name, typ, expr_opt) -> Declaration (name, Some typ, Option.map typed_expr_to_expr expr_opt) | TConstDeclaration (name, typ, expr) -> ConstDeclaration (name, Some typ, typed_expr_to_expr expr) | TReturn expr_opt -> Return (Option.map typed_expr_to_expr expr_opt) - | TIf (cond, then_stmts, else_opt) -> - If (typed_expr_to_expr cond, + | TIf (cond, then_stmts, else_opt) -> + If (typed_expr_to_expr cond, List.map typed_stmt_to_stmt then_stmts, Option.map (List.map typed_stmt_to_stmt) else_opt) + | TIfLet (name, _bound_type, expr, then_stmts, else_opt) -> + IfLet (name, + typed_expr_to_expr expr, + List.map typed_stmt_to_stmt then_stmts, + Option.map (List.map typed_stmt_to_stmt) else_opt) | TFor (var, start, end_, body) -> For (var, typed_expr_to_expr start, typed_expr_to_expr end_, List.map typed_stmt_to_stmt body) | TForIter (index_var, value_var, iterable, body) -> @@ -3237,6 +3310,13 @@ and populate_multi_program_context ast multi_prog_analysis = (match map_expr.program_context with | Some ctx -> map_expr.program_context <- Some { ctx with data_flow_direction = Some Write } | None -> ()) + | CompoundFieldIndexAssignment (map_expr, key_expr, _, _, value_expr) -> + enhance_expr prog_type map_expr; + enhance_expr prog_type key_expr; + enhance_expr prog_type value_expr; + (match map_expr.program_context with + | Some ctx -> map_expr.program_context <- Some { ctx with data_flow_direction = Some Write } + | None -> ()) | FieldAssignment (obj_expr, _, value_expr) -> enhance_expr prog_type obj_expr; enhance_expr prog_type value_expr @@ -3266,6 +3346,12 @@ and populate_multi_program_context ast multi_prog_analysis = (match else_stmts_opt with | Some else_stmts -> List.iter (enhance_stmt prog_type) else_stmts | None -> ()) + | IfLet (_, expr, then_stmts, else_stmts_opt) -> + enhance_expr prog_type expr; + List.iter (enhance_stmt prog_type) then_stmts; + (match else_stmts_opt with + | Some else_stmts -> List.iter (enhance_stmt prog_type) else_stmts + | None -> ()) | For (_, start_expr, end_expr, body_stmts) -> enhance_expr prog_type start_expr; enhance_expr prog_type end_expr; @@ -3336,6 +3422,10 @@ and populate_multi_program_context ast multi_prog_analysis = enhance_userspace_expr map_expr; enhance_userspace_expr key_expr; enhance_userspace_expr value_expr + | CompoundFieldIndexAssignment (map_expr, key_expr, _, _, value_expr) -> + enhance_userspace_expr map_expr; + enhance_userspace_expr key_expr; + enhance_userspace_expr value_expr | FieldAssignment (obj_expr, _, value_expr) -> enhance_userspace_expr obj_expr; enhance_userspace_expr value_expr @@ -3360,6 +3450,12 @@ and populate_multi_program_context ast multi_prog_analysis = (match else_stmts_opt with | Some else_stmts -> List.iter enhance_userspace_stmt_inner else_stmts | None -> ()) + | IfLet (_, expr, then_stmts, else_stmts_opt) -> + enhance_userspace_expr expr; + List.iter enhance_userspace_stmt_inner then_stmts; + (match else_stmts_opt with + | Some else_stmts -> List.iter enhance_userspace_stmt_inner else_stmts + | None -> ()) | For (_, start_expr, end_expr, body_stmts) -> enhance_userspace_expr start_expr; enhance_userspace_expr end_expr; diff --git a/src/userspace_codegen.ml b/src/userspace_codegen.ml index 0c07f08..3cbb319 100644 --- a/src/userspace_codegen.ml +++ b/src/userspace_codegen.ml @@ -1219,7 +1219,6 @@ let determine_global_var_section (global_var : ir_global_variable) = | IRLiteral (Ast.IntLit (Ast.Signed64 0L, _)) -> "bss" (* Zero-initialized integers go to .bss *) | IRLiteral (Ast.BoolLit false) -> "bss" (* False booleans go to .bss *) | IRLiteral (Ast.NullLit) -> "bss" (* Null pointers go to .bss *) - | IRLiteral (Ast.NoneLit) -> "bss" (* None values go to .bss *) | IRLiteral (Ast.IntLit (_, _)) -> "data" (* Non-zero integers go to .data *) | IRLiteral (Ast.BoolLit true) -> "data" (* True booleans go to .data *) | IRLiteral (Ast.StringLit _) -> "data" (* String literals go to .data *) @@ -1277,8 +1276,7 @@ let rec generate_c_value_from_ir ?(auto_deref_map_access=false) ctx ir_value = | IRLiteral (CharLit c) -> sprintf "'%c'" c | IRLiteral (BoolLit b) -> if b then "true" else "false" | IRLiteral (NullLit) -> "NULL" - | IRLiteral (NoneLit) -> "/* none */" - | IRLiteral (StringLit s) -> + | IRLiteral (StringLit s) -> (* Generate simple string literal for userspace *) sprintf "\"%s\"" s | IRLiteral (ArrayLit init_style) -> @@ -1292,7 +1290,6 @@ let rec generate_c_value_from_ir ?(auto_deref_map_access=false) ctx ir_value = | Ast.CharLit c -> sprintf "'%c'" c | Ast.StringLit s -> sprintf "\"%s\"" s | Ast.NullLit -> "NULL" - | Ast.NoneLit -> "/* none */" | Ast.ArrayLit _ -> "{...}" (* nested arrays simplified *) in sprintf "{%s}" fill_str @@ -1303,7 +1300,6 @@ let rec generate_c_value_from_ir ?(auto_deref_map_access=false) ctx ir_value = | Ast.BoolLit b -> if b then "true" else "false" | Ast.StringLit s -> sprintf "\"%s\"" s | Ast.NullLit -> "NULL" - | Ast.NoneLit -> "/* none */" | Ast.ArrayLit _ -> "{...}" (* nested arrays simplified *) ) elems in sprintf "{%s}" (String.concat ", " elem_strs)) @@ -1406,26 +1402,27 @@ let generate_c_expression_from_ir ctx ir_expr = let index_str = generate_c_value_from_ir ctx right_val in sprintf "%s[%s]" array_str index_str | _ -> - (* Check for none comparisons first *) + (* `null` comparisons against a map-access lower to a presence + check against the underlying lookup pointer (or the pointer + value directly), avoiding an extra dereference. *) + let is_absence_lit = function + | IRLiteral (Ast.NullLit) -> true + | _ -> false + in + let pointer_str v = + match v.value_desc with + | IRMapAccess (_, _, _) -> generate_c_value_from_ir ~auto_deref_map_access:false ctx v + | _ -> generate_c_value_from_ir ctx v + in (match left_val.value_desc, op, right_val.value_desc with - | _, IREq, IRLiteral (Ast.NoneLit) - | IRLiteral (Ast.NoneLit), IREq, _ -> - (* Comparison with none: check if pointer is NULL *) - let non_none_val = if left_val.value_desc = IRLiteral (Ast.NoneLit) then right_val else left_val in - (* For IRMapAccess, use the underlying pointer directly for NULL check *) - let val_str = (match non_none_val.value_desc with - | IRMapAccess (_, _, _) -> generate_c_value_from_ir ~auto_deref_map_access:false ctx non_none_val - | _ -> generate_c_value_from_ir ctx non_none_val) in - sprintf "(%s == NULL)" val_str - | _, IRNe, IRLiteral (Ast.NoneLit) - | IRLiteral (Ast.NoneLit), IRNe, _ -> - (* Not-equal comparison with none: check if pointer is not NULL *) - let non_none_val = if left_val.value_desc = IRLiteral (Ast.NoneLit) then right_val else left_val in - (* For IRMapAccess, use the underlying pointer directly for NULL check *) - let val_str = (match non_none_val.value_desc with - | IRMapAccess (_, _, _) -> generate_c_value_from_ir ~auto_deref_map_access:false ctx non_none_val - | _ -> generate_c_value_from_ir ctx non_none_val) in - sprintf "(%s != NULL)" val_str + | _, IREq, _ when is_absence_lit right_val.value_desc -> + sprintf "(%s == NULL)" (pointer_str left_val) + | _, IREq, _ when is_absence_lit left_val.value_desc -> + sprintf "(%s == NULL)" (pointer_str right_val) + | _, IRNe, _ when is_absence_lit right_val.value_desc -> + sprintf "(%s != NULL)" (pointer_str left_val) + | _, IRNe, _ when is_absence_lit left_val.value_desc -> + sprintf "(%s != NULL)" (pointer_str right_val) | _ -> (* Regular binary operation - auto-dereference map access for operands *) let left_str = (match left_val.value_desc with @@ -1817,7 +1814,15 @@ let rec generate_c_instruction_from_ir ctx instruction = let decl_str = generate_c_declaration typ c_var_name in (match init_expr_opt with | Some init_expr -> - let init_str = generate_c_expression_from_ir ctx init_expr in + let init_str = + (match typ, init_expr.expr_desc with + | IRPointer _, IRValue src_val + when (match src_val.value_desc with IRMapAccess _ -> true | _ -> false) -> + (* Pointer-typed variable initialized from a map lookup: keep the pointer. *) + generate_c_value_from_ir ~auto_deref_map_access:false ctx src_val + | _ -> + generate_c_expression_from_ir ctx init_expr) + in sprintf "%s = %s;" decl_str init_str | None -> sprintf "%s;" decl_str)) diff --git a/tests/dune b/tests/dune index 25142e2..198cb2e 100644 --- a/tests/dune +++ b/tests/dune @@ -90,6 +90,11 @@ (modules test_compound_index_assignment) (libraries kernelscript alcotest test_utils str)) +(executable + (name test_iflet) + (modules test_iflet) + (libraries kernelscript alcotest test_utils str)) + (executable (name test_dynptr_bridge) (modules test_dynptr_bridge) @@ -454,6 +459,7 @@ test_map_operations.exe test_evaluator.exe test_compound_index_assignment.exe + test_iflet.exe test_dynptr_bridge.exe test_global_var_ordering.exe test_string_to_array_unification.exe @@ -590,6 +596,10 @@ (alias runtest) (action (run ./test_compound_index_assignment.exe))) +(rule + (alias runtest) + (action (run ./test_iflet.exe))) + (rule (alias runtest) (action (run ./test_dynptr_bridge.exe))) diff --git a/tests/test_compound_index_assignment.ml b/tests/test_compound_index_assignment.ml index 12cf130..00b930f 100644 --- a/tests/test_compound_index_assignment.ml +++ b/tests/test_compound_index_assignment.ml @@ -404,13 +404,87 @@ var packet_counts : hash(1024) let (typed_ast, _) = type_check_and_annotate_ast_with_builtins ast in let symbol_table = Test_utils.Helpers.create_test_symbol_table ast in let ir_multi_program = Kernelscript.Ir_generator.generate_ir typed_ast symbol_table "rate_limiter" in - + (* Check that compilation was successful *) check bool "end-to-end compilation successful" true (List.length (get_programs ir_multi_program) > 0); print_endline "✓ End-to-end compilation test passed" with | e -> failwith ("End-to-end compilation failed: " ^ Printexc.to_string e) +(** Test 14: Compound assignment on a struct field accessed via map index *) +let test_map_index_field_compound_assignment () = + let source = {| +struct Stats { count: u64, bytes: u64 } +var stats : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + stats[1].count += 1 + return XDP_PASS +} +|} in + let ast = parse_string source in + let (typed_ast, _) = type_check_and_annotate_ast_with_builtins ast in + let symbol_table = Test_utils.Helpers.create_test_symbol_table ast in + let ir_multi_program = Kernelscript.Ir_generator.generate_ir typed_ast symbol_table "probe" in + check bool "map[k].field compound assign compiles" true + (List.length (get_programs ir_multi_program) > 0); + print_endline "✓ map[k].field += rhs compiles end-to-end" + +(** Test 15: Codegen for `m[k].field op= rhs` produces the expected eBPF C + shape. This locks in the Phase 2 codegen path: + + (a) The synthetic pointer binding for the lowered IfLet is declared + with a pointer type (`struct Stats* __cidx_field_N`) and is + initialised from the lookup pointer directly — *not* via the + deref-load statement-expression. A regression to the old shape + produced a `struct Stats* x = ({ struct Stats __val = ...; __val; })` + that fails clang -target bpf with a value-to-pointer mismatch. + + (b) The body emits a presence-checked `ptr->field = ptr->field op rhs` + using the underlying map lookup pointer. + + (c) The field's type width matches the struct definition (u64) — i.e. + the codegen does not default to u32 because the synthesized + FieldAccess loses its `expr_type` annotation. *) +let test_map_index_field_compound_codegen () = + let source = {| +struct Stats { count: u64, bytes: u64 } +var stats : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + stats[1].count += 1 + return XDP_PASS +} +|} in + let ast = parse_string source in + let (typed_ast, _) = type_check_and_annotate_ast_with_builtins ast in + let symbol_table = Test_utils.Helpers.create_test_symbol_table ast in + let ir_multi_program = + Kernelscript.Ir_generator.generate_ir typed_ast symbol_table "probe" in + let c = Kernelscript.Ebpf_c_codegen.generate_c_multi_program ir_multi_program in + let contains s = + try let _ = Str.search_forward (Str.regexp_string s) c 0 in true + with Not_found -> false in + (* (a) pointer-typed synthetic binding initialised from the lookup pointer. + The Phase 2 desugaring emits a plain `var __cidx_field_ = m[k]` + (the synthetic name is fresh by construction, so the IfLet alpha- + rename machinery is not needed and is bypassed). The codegen then + produces `struct Stats* __cidx_field_ = __map_lookup_` via the + pointer-from-map-access path in IRVariableDecl. *) + check bool "synthetic binding declared as a struct pointer" true + (contains "struct Stats* __cidx_field_"); + let bad_value_init = contains + "struct Stats* __cidx_field_0 = ({ struct Stats __val" in + check bool "synthetic binding does NOT use deref-load init" false bad_value_init; + (* (b) presence-checked in-place mutation *) + check bool "single map lookup" true (contains "bpf_map_lookup_elem(&stats"); + check bool "presence check" true (contains "!= NULL"); + check bool "ptr->count write" true (contains "->count ="); + (* (c) field width is u64, not the IRU32 default *) + check bool "field access width is u64" true + (contains "__u64 __field_access_"); + print_endline "✓ map[k].field += rhs codegen shape locked in" + let compound_index_assignment_tests = [ "basic_parsing", `Quick, test_basic_parsing; "all_operators_parsing", `Quick, test_all_operators_parsing; @@ -425,6 +499,8 @@ let compound_index_assignment_tests = [ "ir_generation", `Quick, test_ir_generation; "ir_instruction_ordering", `Quick, test_ir_instruction_ordering; "end_to_end_compilation", `Quick, test_end_to_end_compilation; + "map_index_field_compound_assignment", `Quick, test_map_index_field_compound_assignment; + "map_index_field_compound_codegen", `Quick, test_map_index_field_compound_codegen; ] let () = diff --git a/tests/test_enum.ml b/tests/test_enum.ml index 2403dd2..58073cc 100644 --- a/tests/test_enum.ml +++ b/tests/test_enum.ml @@ -557,7 +557,7 @@ let test_enum_array_index () = fn test_enum_index() -> u32 { var proto = TCP var count = protocol_stats[proto] - if (count != none) { + if (count != null) { return count } else { return 0 diff --git a/tests/test_iflet.ml b/tests/test_iflet.ml new file mode 100644 index 0000000..d964a03 --- /dev/null +++ b/tests/test_iflet.ml @@ -0,0 +1,375 @@ +(* + * Copyright 2026 Multikernel Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *) + +(** Tests for the `if (var x = expr)` declaration-as-condition statement. *) + +open Kernelscript.Ast +open Kernelscript.Parse +open Alcotest + +let contains_substr str substr = + try let _ = Str.search_forward (Str.regexp_string substr) str 0 in true + with Not_found -> false + +let typecheck source = + let ast = parse_string source in + let symbol_table = Test_utils.Helpers.create_test_symbol_table ast in + let (typed_ast, _) = + Kernelscript.Type_checker.type_check_and_annotate_ast + ~symbol_table:(Some symbol_table) ast in + (ast, symbol_table, typed_ast) + +let codegen_ebpf source = + let (_ast, symbol_table, typed_ast) = typecheck source in + let ir = Kernelscript.Ir_generator.generate_ir typed_ast symbol_table "test" in + Kernelscript.Ebpf_c_codegen.generate_c_multi_program ir + +let extract_first_stmt source = + let ast = parse_string source in + let attr_func = + List.find (function AttributedFunction _ -> true | _ -> false) ast in + match attr_func with + | AttributedFunction af -> List.nth af.attr_function.func_body 0 + | _ -> failwith "no attributed function" + +(** 1. Parse: bare `if (var x = ...)` produces an IfLet AST node. *) +let test_parse_iflet_no_else () = + let source = {| +var counters : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var c = counters[1]) { + return XDP_DROP + } + return XDP_PASS +} +|} in + let stmt = extract_first_stmt source in + match stmt.stmt_desc with + | IfLet (name, _, _, None) -> + check string "binding name" "c" name + | _ -> fail "expected IfLet without else" + +(** 2. Parse: `if (var x = ...) { } else { }` round-trips with else. *) +let test_parse_iflet_with_else () = + let source = {| +var counters : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var c = counters[1]) { + return XDP_DROP + } else { + return XDP_PASS + } +} +|} in + let stmt = extract_first_stmt source in + match stmt.stmt_desc with + | IfLet (_, _, _, Some _) -> () + | _ -> fail "expected IfLet with else" + +(** 3. Parse: `else if (var ...)` chains via nested IfLet. *) +let test_parse_iflet_else_iflet () = + let source = {| +var a : hash(1024) +var b : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var x = a[1]) { + return XDP_DROP + } else if (var y = b[2]) { + return XDP_PASS + } + return XDP_PASS +} +|} in + let stmt = extract_first_stmt source in + match stmt.stmt_desc with + | IfLet (_, _, _, Some [{ stmt_desc = IfLet _; _ }]) -> () + | _ -> fail "expected outer IfLet whose else is a single IfLet" + +(** 4. Type-check: struct-map binding succeeds; field access in body works. *) +let test_typecheck_struct_binding () = + let source = {| +struct Stats { count: u64, bytes: u64 } +var stats : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var s = stats[1]) { + s.count = s.count + 1 + s.bytes = s.bytes + 100 + } + return XDP_PASS +} +|} in + let _ = typecheck source in + () + +(** 5. Type-check: scalar-map binding succeeds; value used as a value in body. *) +let test_typecheck_scalar_binding () = + let source = {| +var counters : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var c = counters[1]) { + if (c > 100) { + return XDP_DROP + } + } + return XDP_PASS +} +|} in + let _ = typecheck source in + () + +(** 6. Reject: binding referenced from the else-branch. *) +let test_reject_binding_in_else () = + let source = {| +var counters : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var c = counters[1]) { + return XDP_PASS + } else { + var leaked : u64 = c + } + return XDP_PASS +} +|} in + try + let _ = typecheck source in + fail "expected rejection of binding leak into else-branch" + with + | Kernelscript.Symbol_table.Symbol_error _ -> () + | Kernelscript.Type_checker.Type_error _ -> () + +(** 7. Reject: binding referenced after the if-statement (no outer shadow). *) +let test_reject_binding_after_if () = + let source = {| +var counters : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var c = counters[1]) { + return XDP_PASS + } + var leaked : u64 = c + return XDP_PASS +} +|} in + try + let _ = typecheck source in + fail "expected rejection of binding leak past the if-statement" + with + | Kernelscript.Symbol_table.Symbol_error _ -> () + | Kernelscript.Type_checker.Type_error _ -> () + +(** 8. Codegen (struct map): single lookup + presence check + in-place mutation + with no manual write-back. *) +let test_codegen_struct_in_place () = + let source = {| +struct Stats { count: u64, bytes: u64 } +var stats : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var s = stats[1]) { + s.count = s.count + 1 + } + return XDP_PASS +} +|} in + let c = codegen_ebpf source in + check bool "single map lookup" true (contains_substr c "bpf_map_lookup_elem(&stats"); + check bool "presence check" true (contains_substr c "!= NULL"); + check bool "in-place ptr->field write" true + (contains_substr c "->count ="); + (* In-place mutation should mean no bpf_map_update_elem in the truthy branch. + The else branch is omitted in the source, so there should be zero updates. *) + let has_update = + try let _ = Str.search_forward + (Str.regexp_string "bpf_map_update_elem(&stats") c 0 in true + with Not_found -> false in + check bool "no manual write-back update" false has_update + +(** 9. Codegen (scalar map): the binding holds the dereffed value, and the + presence check uses the underlying lookup pointer. *) +let test_codegen_scalar_value_binding () = + let source = {| +var counters : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var c = counters[1]) { + if (c > 100) { + return XDP_DROP + } + } + return XDP_PASS +} +|} in + let c = codegen_ebpf source in + (* The IfLet binding is alpha-renamed to a fresh synthetic name during IR + lowering (see `subst_ident_stmts` in ir_generator.ml) so that an outer + variable of the same name is not silently clobbered when the backend + hoists declarations to function scope. The synthetic name has the + form `__iflet__`. *) + check bool "scalar binding declared as value, not pointer" true + (contains_substr c "__u64 __iflet_c_"); + check bool "binding init uses the dereffed value statement-expression" true + (contains_substr c "__val = *("); + check bool "presence check on the underlying lookup pointer" true + (contains_substr c "!= NULL") + +(** 10. Codegen (struct map, end-to-end shape): the binding is declared with + the value type (the type-checker auto-derefs `m[k]` to the struct + value), but the field operations in the body lower to in-place + mutation through the underlying lookup pointer rather than through + the local. The local is therefore dead — clang elides it — but its + declaration is still syntactically a value, not a pointer. + + Concretely the previous codegen shape was, for user-written code: + struct Stats* __map_lookup_N; + __map_lookup_N = bpf_map_lookup_elem(&stats, &k); + struct Stats s = ({ struct Stats __val = {0}; + if (__map_lookup_N) { __val = *(__map_lookup_N); } + __val; }); + if (__map_lookup_N != NULL) { + ... __map_lookup_N->count = ... ; + } + Phase 2 only changed the synthetic-pointer-binding path (used by the + lowered `m[k].field op= rhs`); user-written IfLet still produces the + value-typed local above. Pinning that here so any future change to + the typing rule is intentional. *) +let test_codegen_struct_value_binding_shape () = + let source = {| +struct Stats { count: u64 } +var stats : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var s = stats[1]) { + s.count = s.count + 1 + } + return XDP_PASS +} +|} in + let c = codegen_ebpf source in + (* Binding is alpha-renamed to `__iflet_s_` — see the comment on + `test_codegen_scalar_value_binding` for why. *) + check bool "binding declared with value type, not pointer" true + (contains_substr c "struct Stats __iflet_s_"); + check bool "value-typed binding uses deref-load init" true + (contains_substr c "struct Stats __val"); + check bool "field write goes through the underlying lookup pointer" true + (contains_substr c "->count =") + +(** 11a. Reject: int-literal RHS — `if (var x = 5)` is not a presence check. + The construct only makes sense when the RHS is a map access (auto- + deref'd to a value but underlying-pointer-checked) or a pointer-typed + expression. An integer RHS would lower to `__u32 x; if (x != NULL)`, + which warns under -Wpointer-integer-compare and is semantically + incoherent — also the evaluator's truthiness rules diverge from the + codegen's `!= NULL` for non-pointer types. *) +let test_reject_int_literal_rhs () = + let source = {| +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var x = 5) { + return XDP_PASS + } + return XDP_DROP +} +|} in + try + let _ = typecheck source in + fail "expected rejection of integer-literal RHS" + with + | Kernelscript.Type_checker.Type_error _ -> () + +(** 11b. Reject: non-pointer-returning function RHS. *) +let test_reject_non_pointer_call_rhs () = + let source = {| +@helper fn returns_zero() -> u32 { + return 0 +} + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + if (var x = returns_zero()) { + return XDP_PASS + } + return XDP_DROP +} +|} in + try + let _ = typecheck source in + fail "expected rejection of non-pointer-returning call as RHS" + with + | Kernelscript.Type_checker.Type_error _ -> () + +(** 11. Codegen (shadowing): an outer binding of the same name as the IfLet + binding must survive both branches and remain referenceable after the + if. The branch-local invariant the frontend enforces (binding visible + only inside the then-branch) has to be preserved end-to-end through + IR lowering — i.e., the inner binding cannot collapse onto the outer + name in the generated C. *) +let test_codegen_shadow_outer_binding () = + let source = {| +var counters : hash(1024) + +@xdp fn probe(ctx: *xdp_md) -> xdp_action { + var c : u64 = 100 + if (var c = counters[1]) { + return XDP_DROP + } + if (c == 100) { + return XDP_PASS + } + return XDP_DROP +} +|} in + let c = codegen_ebpf source in + (* The outer `c = 100` declaration must remain literally — the inner binding + must not reuse the name. *) + check bool "outer c declared with literal value" true + (contains_substr c "__u64 c = 100"); + (* The outer `c` must NOT be reassigned by the IfLet's lowering. The bug + symptom was a statement-expression assignment `c = ({ ... });` that + clobbered the outer binding with the lookup result (or zero on miss). + A bare `c = ({` (no `__u64` prefix) is the giveaway. *) + let outer_clobber = + try let _ = Str.search_forward + (Str.regexp "[^_a-zA-Z0-9]c = ({") c 0 in true + with Not_found -> false in + check bool "outer c not clobbered by iflet init" false outer_clobber; + (* The post-if comparison `c == 100` must reference the outer `c`, not be + rewritten into another fresh map deref. *) + check bool "post-if uses outer c by name" true + (contains_substr c "(c == 100)") + +let suite = [ + "parse_iflet_no_else", `Quick, test_parse_iflet_no_else; + "parse_iflet_with_else", `Quick, test_parse_iflet_with_else; + "parse_iflet_else_iflet", `Quick, test_parse_iflet_else_iflet; + "typecheck_struct_binding", `Quick, test_typecheck_struct_binding; + "typecheck_scalar_binding", `Quick, test_typecheck_scalar_binding; + "reject_binding_in_else", `Quick, test_reject_binding_in_else; + "reject_binding_after_if", `Quick, test_reject_binding_after_if; + "codegen_struct_in_place", `Quick, test_codegen_struct_in_place; + "codegen_scalar_value_binding", `Quick, test_codegen_scalar_value_binding; + "codegen_struct_value_binding_shape", `Quick, test_codegen_struct_value_binding_shape; + "codegen_shadow_outer_binding", `Quick, test_codegen_shadow_outer_binding; + "reject_int_literal_rhs", `Quick, test_reject_int_literal_rhs; + "reject_non_pointer_call_rhs", `Quick, test_reject_non_pointer_call_rhs; +] + +let () = + run "IfLet (declaration-as-condition)" [ "iflet", suite ] diff --git a/tests/test_map_integration.ml b/tests/test_map_integration.ml index 55bd065..5fd072c 100644 --- a/tests/test_map_integration.ml +++ b/tests/test_map_integration.ml @@ -430,7 +430,7 @@ var test_map : hash(64) @helper fn get_value(key: u32) -> TestEnum { var result = test_map[key] - if (result != none) { + if (result != null) { return result // This should return the dereferenced value, not pointer } else { return VALUE_A diff --git a/tests/test_map_operations.ml b/tests/test_map_operations.ml index 8b7fb64..7991836 100644 --- a/tests/test_map_operations.ml +++ b/tests/test_map_operations.ml @@ -78,7 +78,7 @@ let test_map_origin_conditional_assignments () = @xdp fn test_conditional(ctx: *xdp_md) -> xdp_action { var user_id: u32 = 123 var stats = user_stats[user_id] - if (stats != none) { + if (stats != null) { var local_stats = stats print("Stats: {}", local_stats) } @@ -113,7 +113,7 @@ let test_address_of_map_values () = @xdp fn test_address_of(ctx: *xdp_md) -> xdp_action { var user_id: u32 = 123 var stats = user_stats[user_id] - if (stats != none) { + if (stats != null) { var ptr = &stats print("Stats pointer: {}", ptr) } @@ -148,7 +148,7 @@ let test_address_of_type_checking () = @xdp fn test_address_of_types(ctx: *xdp_md) -> xdp_action { var user_id: u32 = 123 var stats = user_stats[user_id] - if (stats != none) { + if (stats != null) { var ptr: *u64 = &stats print("Stats value: {}", *ptr) } @@ -170,7 +170,7 @@ let test_address_of_contexts () = var user_id: u32 = 123 var stats = user_stats[user_id] - if (stats != none) { + if (stats != null) { // Address-of in if statement var ptr1 = &stats @@ -200,7 +200,7 @@ let test_none_comparison_map_values () = var user_id: u32 = 123 var stats = user_stats[user_id] - if (stats != none) { + if (stats != null) { print("Stats found: {}", stats) } else { print("Stats not found") @@ -226,17 +226,17 @@ let test_none_comparison_different_map_types () = var user_id: u32 = 123 var hash_stats = hash_map[user_id] - if (hash_stats != none) { + if (hash_stats != null) { print("Hash stats: {}", hash_stats) } var lru_stats = lru_map[user_id] - if (lru_stats != none) { + if (lru_stats != null) { print("LRU stats: {}", lru_stats) } var percpu_stats = percpu_map[user_id] - if (percpu_stats != none) { + if (percpu_stats != null) { print("PerCPU stats: {}", percpu_stats) } @@ -259,13 +259,13 @@ let test_none_comparison_conditional_statements () = var stats = user_stats[user_id] // Test in if statement - if (stats != none) { + if (stats != null) { var local_stats = stats print("Found stats: {}", local_stats) } // Test in while statement - while (stats != none) { + while (stats != null) { print("Processing stats: {}", stats) break } @@ -290,17 +290,17 @@ let test_none_comparison_different_value_types () = var key: u32 = 123 var u32_val = u32_map[key] - if (u32_val != none) { + if (u32_val != null) { print("U32 value: {}", u32_val) } var u64_val = u64_map[key] - if (u64_val != none) { + if (u64_val != null) { print("U64 value: {}", u64_val) } var bool_val = bool_map[key] - if (bool_val != none) { + if (bool_val != null) { print("Bool value: {}", bool_val) } @@ -324,7 +324,7 @@ let test_complex_map_value_scenarios () = var stats = user_stats[user_id] var counts = user_counts[user_id] - if (stats != none && counts != none) { + if (stats != null && counts != null) { var stats_ptr = &stats var counts_ptr = &counts @@ -356,7 +356,7 @@ let test_nested_map_value_access () = var current_id = user_id + i var stats = user_stats[current_id] - if (stats != none) { + if (stats != null) { var local_stats = stats var stats_ptr = &local_stats @@ -382,11 +382,12 @@ let test_nested_map_value_access () = (** Test error cases for map value operations *) let test_map_value_error_cases () = - (* Test 1: Invalid none comparison with non-map values *) + (* Test 1: comparing a regular variable against `null` parses fine + (the type-checker accepts any pointer-coercible comparison). *) let test_program1 = {| - @xdp fn test_invalid_none(ctx: *xdp_md) -> xdp_action { + @xdp fn test_null_compare(ctx: *xdp_md) -> xdp_action { var regular_var: u32 = 42 - if (regular_var != none) { // This should be an error + if (regular_var != null) { print("Regular var: {}", regular_var) } return 0 diff --git a/tests/test_truthy_falsy.ml b/tests/test_truthy_falsy.ml index 359aca7..06fd0c1 100644 --- a/tests/test_truthy_falsy.ml +++ b/tests/test_truthy_falsy.ml @@ -116,9 +116,6 @@ let test_truthy_evaluation () = check bool "Null pointer is falsy" (is_truthy_value (PointerValue 0)) false; check bool "Non-null pointer is truthy" (is_truthy_value (PointerValue 0x1234)) true; - (* Test none sentinel - always falsy *) - check bool "none is falsy" (is_truthy_value None) false; - (* Test that structs and arrays cannot be used in boolean context *) (try let _ = is_truthy_value (ArrayValue [||]) in diff --git a/tests/test_type_checker.ml b/tests/test_type_checker.ml index 9713bfc..3af05c1 100644 --- a/tests/test_type_checker.ml +++ b/tests/test_type_checker.ml @@ -1223,7 +1223,7 @@ var packet_filter : lru_hash(512) // Struct key var count2 = protocol_stats[proto] // Enum as key var result = packet_filter[info] // Struct as key - if (count1 != none && count2 != none && result != none) { + if (count1 != null && count2 != null && result != null) { return count1 + count2 + result } else { return 0