Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions src/InternPool.zig
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ pub const Key = union(enum) {

/// The returned pointer expires with any addition to the `InternPool`.
/// Asserts the struct is not packed.
pub fn flagsPtr(self: @This(), ip: *InternPool) *Tag.TypeStruct.Flags {
pub fn flagsPtr(self: @This(), ip: *const InternPool) *Tag.TypeStruct.Flags {
assert(self.layout != .Packed);
const flags_field_index = std.meta.fieldIndex(Tag.TypeStruct, "flags").?;
return @ptrCast(&ip.extra.items[self.extra_index + flags_field_index]);
Expand Down Expand Up @@ -687,6 +687,18 @@ pub const Key = union(enum) {
return @ptrCast(&ip.extra.items[self.extra_index + flags_field_index]);
}

/// The returned pointer expires with any addition to the `InternPool`.
pub fn size(self: @This(), ip: *InternPool) *u32 {
const size_field_index = std.meta.fieldIndex(Tag.TypeUnion, "size").?;
return &ip.extra.items[self.extra_index + size_field_index];
}

/// The returned pointer expires with any addition to the `InternPool`.
pub fn padding(self: @This(), ip: *InternPool) *u32 {
const padding_field_index = std.meta.fieldIndex(Tag.TypeUnion, "padding").?;
return &ip.extra.items[self.extra_index + padding_field_index];
}

pub fn haveFieldTypes(self: @This(), ip: *const InternPool) bool {
return self.flagsPtr(ip).status.haveFieldTypes();
}
Expand Down Expand Up @@ -1744,6 +1756,10 @@ pub const UnionType = struct {
enum_tag_ty: Index,
/// The integer tag type of the enum.
int_tag_ty: Index,
/// ABI size of the union, including padding
size: u64,
/// Trailing padding bytes
padding: u32,
/// List of field names in declaration order.
field_names: NullTerminatedString.Slice,
/// List of field types in declaration order.
Expand Down Expand Up @@ -1830,6 +1846,10 @@ pub const UnionType = struct {
return self.flagsPtr(ip).runtime_tag.hasTag();
}

pub fn haveFieldTypes(self: UnionType, ip: *const InternPool) bool {
return self.flagsPtr(ip).status.haveFieldTypes();
}

pub fn haveLayout(self: UnionType, ip: *const InternPool) bool {
return self.flagsPtr(ip).status.haveLayout();
}
Expand Down Expand Up @@ -1867,6 +1887,8 @@ pub fn loadUnionType(ip: *InternPool, key: Key.UnionType) UnionType {
.namespace = type_union.data.namespace,
.enum_tag_ty = enum_ty,
.int_tag_ty = enum_info.tag_ty,
.size = type_union.data.padding,
.padding = type_union.data.padding,
.field_names = enum_info.names,
.names_map = enum_info.names_map,
.field_types = .{
Expand Down Expand Up @@ -2936,6 +2958,10 @@ pub const Tag = enum(u8) {
/// 1. field align: Alignment for each field; declaration order
pub const TypeUnion = struct {
flags: Flags,
/// Only valid after .have_layout
size: u32,
/// Only valid after .have_layout
padding: u32,
decl: Module.Decl.Index,
namespace: Module.Namespace.Index,
/// The enum that provides the list of field names and values.
Expand All @@ -2950,7 +2976,9 @@ pub const Tag = enum(u8) {
status: UnionType.Status,
requires_comptime: RequiresComptime,
assumed_runtime_bits: bool,
_: u21 = 0,
assumed_pointer_aligned: bool,
alignment: Alignment,
_: u14 = 0,
};
};

Expand Down Expand Up @@ -3014,7 +3042,7 @@ pub const Tag = enum(u8) {
any_comptime_fields: bool,
any_default_inits: bool,
any_aligned_fields: bool,
/// `undefined` until the layout_resolved
/// `.none` until layout_resolved
alignment: Alignment,
/// Dependency loop detection when resolving struct alignment.
alignment_wip: bool,
Expand Down Expand Up @@ -5255,6 +5283,8 @@ pub fn getUnionType(ip: *InternPool, gpa: Allocator, ini: UnionTypeInit) Allocat

const union_type_extra_index = ip.addExtraAssumeCapacity(Tag.TypeUnion{
.flags = ini.flags,
.size = std.math.maxInt(u32),
.padding = std.math.maxInt(u32),
.decl = ini.decl,
.namespace = ini.namespace,
.tag_ty = ini.enum_tag_ty,
Expand Down
26 changes: 3 additions & 23 deletions src/Module.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6499,31 +6499,11 @@ pub fn getUnionLayout(mod: *Module, u: InternPool.UnionType) UnionLayout {
.padding = 0,
};
}
// Put the tag before or after the payload depending on which one's
// alignment is greater.

const tag_size = u.enum_tag_ty.toType().abiSize(mod);
const tag_align = u.enum_tag_ty.toType().abiAlignment(mod).max(.@"1");
var size: u64 = 0;
var padding: u32 = undefined;
if (tag_align.compare(.gte, payload_align)) {
// {Tag, Payload}
size += tag_size;
size = payload_align.forward(size);
size += payload_size;
const prev_size = size;
size = tag_align.forward(size);
padding = @intCast(size - prev_size);
} else {
// {Payload, Tag}
size += payload_size;
size = tag_align.forward(size);
size += tag_size;
const prev_size = size;
size = payload_align.forward(size);
padding = @intCast(size - prev_size);
}
return .{
.abi_size = size,
.abi_size = u.size,
.abi_align = tag_align.max(payload_align),
.most_aligned_field = most_aligned_field,
.most_aligned_field_size = most_aligned_field_size,
Expand All @@ -6532,7 +6512,7 @@ pub fn getUnionLayout(mod: *Module, u: InternPool.UnionType) UnionLayout {
.payload_align = payload_align,
.tag_align = tag_align,
.tag_size = tag_size,
.padding = padding,
.padding = u.padding,
};
}

Expand Down
130 changes: 121 additions & 9 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3193,6 +3193,8 @@ fn zirUnionDecl(
.any_aligned_fields = small.any_aligned_fields,
.requires_comptime = .unknown,
.assumed_runtime_bits = false,
.assumed_pointer_aligned = false,
.alignment = .none,
},
.decl = new_decl_index,
.namespace = new_namespace_index,
Expand Down Expand Up @@ -20962,6 +20964,8 @@ fn zirReify(
.any_aligned_fields = any_aligned_fields,
.requires_comptime = .unknown,
.assumed_runtime_bits = false,
.assumed_pointer_aligned = false,
.alignment = .none,
},
.field_types = union_fields.items(.type),
.field_aligns = if (any_aligned_fields) union_fields.items(.alignment) else &.{},
Expand Down Expand Up @@ -34895,11 +34899,59 @@ fn checkMemOperand(sema: *Sema, block: *Block, src: LazySrcLoc, ty: Type) !void
return sema.failWithOwnedErrorMsg(block, msg);
}

/// Resolve a unions's alignment only without triggering resolution of its layout.
/// Asserts that the alignment is not yet resolved.
pub fn resolveUnionAlignment(
sema: *Sema,
ty: Type,
union_type: InternPool.Key.UnionType,
) CompileError!Alignment {
const mod = sema.mod;
const ip = &mod.intern_pool;
const target = mod.getTarget();

assert(!union_type.haveLayout(ip));

if (union_type.flagsPtr(ip).status == .field_types_wip) {
// We'll guess "pointer-aligned", if the union has an
// underaligned pointer field then some allocations
// might require explicit alignment.
union_type.flagsPtr(ip).assumed_pointer_aligned = true;
const result = Alignment.fromByteUnits(@divExact(target.ptrBitWidth(), 8));
union_type.flagsPtr(ip).alignment = result;
return result;
}

try sema.resolveTypeFieldsUnion(ty, union_type);

const union_obj = ip.loadUnionType(union_type);
var max_align: Alignment = .@"1";
for (0..union_obj.field_names.len) |field_index| {
const field_ty = union_obj.field_types.get(ip)[field_index].toType();
if (!(try sema.typeHasRuntimeBits(field_ty))) continue;

const explicit_align = union_obj.fieldAlign(ip, @intCast(field_index));
const field_align = if (explicit_align != .none)
explicit_align
else
try sema.typeAbiAlignment(field_ty);

max_align = max_align.max(field_align);
}

union_type.flagsPtr(ip).alignment = max_align;
return max_align;
}

/// This logic must be kept in sync with `Module.getUnionLayout`.
fn resolveUnionLayout(sema: *Sema, ty: Type) CompileError!void {
const mod = sema.mod;
const ip = &mod.intern_pool;
try sema.resolveTypeFields(ty);
const union_obj = mod.typeToUnion(ty).?;

const union_type = ip.indexToKey(ty.ip_index).union_type;
try sema.resolveTypeFieldsUnion(ty, union_type);

const union_obj = ip.loadUnionType(union_type);
switch (union_obj.flagsPtr(ip).status) {
.none, .have_field_types => {},
.field_types_wip, .layout_wip => {
Expand All @@ -34913,25 +34965,74 @@ fn resolveUnionLayout(sema: *Sema, ty: Type) CompileError!void {
},
.have_layout, .fully_resolved_wip, .fully_resolved => return,
}

const prev_status = union_obj.flagsPtr(ip).status;
errdefer if (union_obj.flagsPtr(ip).status == .layout_wip) {
union_obj.flagsPtr(ip).status = prev_status;
};

union_obj.flagsPtr(ip).status = .layout_wip;
for (0..union_obj.field_types.len) |field_index| {

var max_size: u64 = 0;
var max_align: Alignment = .@"1";
for (0..union_obj.field_names.len) |field_index| {
const field_ty = union_obj.field_types.get(ip)[field_index].toType();
sema.resolveTypeLayout(field_ty) catch |err| switch (err) {
if (!(try sema.typeHasRuntimeBits(field_ty))) continue;

max_size = @max(max_size, sema.typeAbiSize(field_ty) catch |err| switch (err) {
error.AnalysisFail => {
const msg = sema.err orelse return err;
try sema.addFieldErrNote(ty, field_index, msg, "while checking this field", .{});
return err;
},
else => return err,
};
}
union_obj.flagsPtr(ip).status = .have_layout;
_ = try sema.typeRequiresComptime(ty);
});

const explicit_align = union_obj.fieldAlign(ip, @intCast(field_index));
const field_align = if (explicit_align != .none)
explicit_align
else
try sema.typeAbiAlignment(field_ty);

max_align = max_align.max(field_align);
}

const flags = union_obj.flagsPtr(ip);
const has_runtime_tag = flags.runtime_tag.hasTag() and try sema.typeHasRuntimeBits(union_obj.enum_tag_ty.toType());
const size, const alignment, const padding = if (has_runtime_tag) layout: {
const enum_tag_type = union_obj.enum_tag_ty.toType();
const tag_align = try sema.typeAbiAlignment(enum_tag_type);
const tag_size = try sema.typeAbiSize(enum_tag_type);

// Put the tag before or after the payload depending on which one's
// alignment is greater.
var size: u64 = 0;
var padding: u32 = 0;
if (tag_align.compare(.gte, max_align)) {
// {Tag, Payload}
size += tag_size;
size = max_align.forward(size);
size += max_size;
const prev_size = size;
size = tag_align.forward(size);
padding = @intCast(size - prev_size);
} else {
// {Payload, Tag}
size += max_size;
size = tag_align.forward(size);
size += tag_size;
const prev_size = size;
size = max_align.forward(size);
padding = @intCast(size - prev_size);
}

break :layout .{ size, max_align.max(tag_align), padding };
} else .{ max_align.forward(max_size), max_align, 0 };

union_type.size(ip).* = @intCast(size);
union_type.padding(ip).* = padding;
flags.alignment = alignment;
flags.status = .have_layout;

if (union_obj.flagsPtr(ip).assumed_runtime_bits and !(try sema.typeHasRuntimeBits(ty))) {
const msg = try Module.ErrorMsg.create(
Expand All @@ -34942,6 +35043,18 @@ fn resolveUnionLayout(sema: *Sema, ty: Type) CompileError!void {
);
return sema.failWithOwnedErrorMsg(null, msg);
}

if (union_obj.flagsPtr(ip).assumed_pointer_aligned and
alignment.compareStrict(.neq, Alignment.fromByteUnits(@divExact(mod.getTarget().ptrBitWidth(), 8))))
{
const msg = try Module.ErrorMsg.create(
sema.gpa,
mod.declPtr(union_obj.decl).srcLoc(mod),
"union layout depends on being pointer aligned",
.{},
);
return sema.failWithOwnedErrorMsg(null, msg);
}
}

/// Returns `error.AnalysisFail` if any of the types (recursively) failed to
Expand Down Expand Up @@ -35008,7 +35121,6 @@ fn resolveStructFully(sema: *Sema, ty: Type) CompileError!void {

fn resolveUnionFully(sema: *Sema, ty: Type) CompileError!void {
try sema.resolveUnionLayout(ty);
try sema.resolveTypeFields(ty);

const mod = sema.mod;
const ip = &mod.intern_pool;
Expand Down
Loading