Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions js/compressed-token/src/v3/actions/create-mint-interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ export async function createMintInterface(

// Default: light-token mint creation
if (!('secretKey' in mintAuthority)) {
throw new Error(
'mintAuthority must be a Signer for light-token mints',
);
throw new Error('mintAuthority must be a Signer for light-token mints');
}
if (
addressTreeInfo &&
Expand Down
4 changes: 1 addition & 3 deletions js/compressed-token/src/v3/get-mint-interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ export async function getMintInterface(
);

if (!compressedAccount?.data?.data) {
throw new Error(
`Light mint not found for ${address.toString()}`,
);
throw new Error(`Light mint not found for ${address.toString()}`);
}

const compressedData = Buffer.from(compressedAccount.data.data);
Expand Down
9 changes: 8 additions & 1 deletion sdk-libs/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,14 @@ pub fn light_account_derive(input: TokenStream) -> TokenStream {
/// - The `compression_info` field must be first or last field in the struct
/// - Struct should be `#[repr(C)]` for predictable memory layout
/// - Use `[u8; 32]` instead of `Pubkey` for address fields
#[proc_macro_derive(LightPinocchioAccount, attributes(compress_as, skip))]
///
/// ## Custom discriminator
///
/// Use `#[light_pinocchio(discriminator = [1u8])]` to override the default
/// 8-byte SHA256 discriminator with a shorter custom discriminator (1-8 bytes).
/// Variants with short discriminators should be declared last in `ProgramAccounts`
/// enums to avoid prefix-matching conflicts during dispatch.
#[proc_macro_derive(LightPinocchioAccount, attributes(compress_as, skip, light_pinocchio))]
pub fn light_pinocchio_account_derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
into_token_stream(light_pdas::account::derive::derive_light_pinocchio_account(
Expand Down
193 changes: 191 additions & 2 deletions sdk-libs/macros/src/light_pdas/account/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,68 @@ pub fn derive_light_pinocchio_account(input: DeriveInput) -> Result<TokenStream>
derive_light_account_internal(input, Framework::Pinocchio)
}

/// Parses the `discriminator` bytes from `#[light_pinocchio(discriminator = [...])]` if present.
/// Returns None if the attribute is absent (use hash-derived discriminator).
fn parse_pinocchio_discriminator(attrs: &[syn::Attribute]) -> Result<Option<Vec<u8>>> {
for attr in attrs {
if !attr.path().is_ident("light_pinocchio") {
continue;
}
let meta_list = attr.meta.require_list()?;
let nested: Punctuated<syn::Meta, Token![,]> =
meta_list.parse_args_with(Punctuated::parse_terminated)?;
for meta in &nested {
if let syn::Meta::NameValue(nv) = meta {
if nv.path.is_ident("discriminator") {
if let syn::Expr::Array(arr) = &nv.value {
let bytes: Vec<u8> = arr
.elems
.iter()
.map(|e| {
if let syn::Expr::Lit(lit) = e {
if let syn::Lit::Int(i) = &lit.lit {
return i
.base10_parse::<u8>()
.map_err(|err| syn::Error::new_spanned(i, err));
}
}
if let syn::Expr::Cast(cast) = e {
if let syn::Expr::Lit(lit) = cast.expr.as_ref() {
if let syn::Lit::Int(i) = &lit.lit {
return i
.base10_parse::<u8>()
.map_err(|err| syn::Error::new_spanned(i, err));
}
}
}
Err(syn::Error::new_spanned(e, "expected integer literal"))
})
.collect::<Result<Vec<u8>>>()?;
if bytes.is_empty() {
return Err(syn::Error::new_spanned(
arr,
"discriminator must have at least one byte",
));
}
if bytes.len() > 8 {
return Err(syn::Error::new_spanned(
arr,
"discriminator must not exceed 8 bytes",
));
}
return Ok(Some(bytes));
}
return Err(syn::Error::new_spanned(
&nv.value,
"discriminator must be an array like [1u8]",
));
}
}
}
}
Ok(None)
}

/// Internal implementation of LightAccount derive, parameterized by framework.
fn derive_light_account_internal(input: DeriveInput, framework: Framework) -> Result<TokenStream> {
// Convert DeriveInput to ItemStruct for macros that need it
Expand All @@ -125,8 +187,35 @@ fn derive_light_account_internal(input: DeriveInput, framework: Framework) -> Re
// Generate LightHasherSha implementation
let hasher_impl = derive_light_hasher_sha(item_struct.clone())?;

// Generate LightDiscriminator implementation
let discriminator_impl = discriminator::anchor_discriminator(item_struct)?;
// Check for custom discriminator argument from #[light_pinocchio(discriminator = [...])]
// Only valid for the Pinocchio framework; reject it on Anchor to avoid silent misuse.
let discriminator_impl = if let Some(disc_bytes) = parse_pinocchio_discriminator(&input.attrs)?
{
if framework != Framework::Pinocchio {
return Err(syn::Error::new_spanned(
&input.ident,
"#[light_pinocchio(discriminator = [...])] is only valid with \
#[derive(LightPinocchioAccount)], not with #[derive(LightAccount)]",
));
}
let mut padded = [0u8; 8];
let copy_len = disc_bytes.len().min(8);
padded[..copy_len].copy_from_slice(&disc_bytes[..copy_len]);
let discriminator_tokens: proc_macro2::TokenStream = format!("{padded:?}").parse().unwrap();
let slice_tokens: proc_macro2::TokenStream = format!("{disc_bytes:?}").parse().unwrap();
let struct_name = &input.ident;
let (impl_gen, type_gen, where_clause) = input.generics.split_for_impl();
quote! {
impl #impl_gen LightDiscriminator for #struct_name #type_gen #where_clause {
const LIGHT_DISCRIMINATOR: [u8; 8] = #discriminator_tokens;
const LIGHT_DISCRIMINATOR_SLICE: &'static [u8] = &#slice_tokens;
fn discriminator() -> [u8; 8] { Self::LIGHT_DISCRIMINATOR }
}
}
} else {
// Generate LightDiscriminator implementation via SHA256
discriminator::anchor_discriminator(item_struct)?
};

// Generate unified LightAccount implementation (includes PackedXxx struct)
let light_account_impl = generate_light_account_impl(&input, framework)?;
Expand Down Expand Up @@ -747,6 +836,106 @@ mod tests {

use super::*;

#[test]
fn test_light_pinocchio_custom_discriminator() {
let input: DeriveInput = parse_quote! {
#[light_pinocchio(discriminator = [1u8])]
pub struct OneByteRecord {
pub compression_info: CompressionInfo,
pub owner: [u8; 32],
}
};

let result = derive_light_pinocchio_account(input);
assert!(
result.is_ok(),
"LightPinocchioAccount with custom discriminator should succeed: {:?}",
result.err()
);

let output = result.unwrap().to_string();

// Should contain custom discriminator (1, 0, 0, 0, 0, 0, 0, 0)
assert!(
output.contains("LIGHT_DISCRIMINATOR"),
"Should have LIGHT_DISCRIMINATOR"
);
assert!(
output.contains("1 , 0 , 0 , 0 , 0 , 0 , 0 , 0")
|| output.contains("1, 0, 0, 0, 0, 0, 0, 0"),
"LIGHT_DISCRIMINATOR should be [1,0,0,0,0,0,0,0]"
);
// LIGHT_DISCRIMINATOR_SLICE must be &[1] (1 byte), NOT the padded &[1, 0, 0, 0, 0, 0, 0, 0]
assert!(
output.contains("LIGHT_DISCRIMINATOR_SLICE"),
"Should have LIGHT_DISCRIMINATOR_SLICE"
);
// Verify the slice contains exactly 1 element (not 8)
// The generated token stream renders as `& [1u8]` or `& [1]`
assert!(
output.contains("& [1u8]") || output.contains("& [1]"),
"LIGHT_DISCRIMINATOR_SLICE should be &[1] (1 byte), got: {output}"
);
}

#[test]
fn test_light_pinocchio_custom_discriminator_empty_rejected() {
let input: DeriveInput = parse_quote! {
#[light_pinocchio(discriminator = [])]
pub struct EmptyDisc {
pub compression_info: CompressionInfo,
pub owner: [u8; 32],
}
};
let result = derive_light_pinocchio_account(input);
assert!(
result.is_err(),
"Empty discriminator array should be rejected"
);
let err = result.unwrap_err().to_string();
assert!(
err.contains("at least one byte"),
"Error should mention 'at least one byte', got: {err}"
);
}

#[test]
fn test_light_pinocchio_custom_discriminator_too_long_rejected() {
let input: DeriveInput = parse_quote! {
#[light_pinocchio(discriminator = [1, 2, 3, 4, 5, 6, 7, 8, 9])]
pub struct TooLongDisc {
pub compression_info: CompressionInfo,
pub owner: [u8; 32],
}
};
let result = derive_light_pinocchio_account(input);
assert!(
result.is_err(),
"Discriminator longer than 8 bytes should be rejected"
);
let err = result.unwrap_err().to_string();
assert!(
err.contains("exceed 8 bytes"),
"Error should mention max length, got: {err}"
);
}

#[test]
fn test_light_pinocchio_discriminator_rejected_on_anchor() {
let input: DeriveInput = parse_quote! {
#[light_pinocchio(discriminator = [1u8])]
pub struct AnchorRecord {
pub compression_info: CompressionInfo,
pub owner: Pubkey,
}
};
let result = derive_light_account(input);
assert!(
result.is_err(),
"#[light_pinocchio(discriminator)] should be rejected with LightAccount (Anchor)"
);
}

#[test]
fn test_light_account_basic() {
let input: DeriveInput = parse_quote! {
Expand Down
Loading