Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions config/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Contract struct {
Location string
Aliases Aliases
IsDependency bool
Canonical string // Reference to canonical contract name if this is an alias
}

// Alias defines an existing pre-deployed contract address for specific network.
Expand Down Expand Up @@ -74,6 +75,19 @@ func (c *Contract) IsAliased() bool {
return len(c.Aliases) > 0
}

// IsAlias checks if this contract is an alias to another contract.
func (c *Contract) IsAlias() bool {
return c.Canonical != ""
}

// CanonicalName returns the canonical contract name if this is an alias, otherwise returns the contract's own name.
func (c *Contract) CanonicalName() string {
if c.Canonical != "" {
return c.Canonical
}
return c.Name
}

// ByName get contract by name or return an error if it doesn't exist.
func (c *Contracts) ByName(name string) (*Contract, error) {
for i, contract := range *c {
Expand Down Expand Up @@ -112,6 +126,30 @@ func (c *Contracts) Remove(name string) error {
return nil
}

// ValidateCanonical validates that all canonical references are valid.
func (c *Contracts) ValidateCanonical() error {
for _, contract := range *c {
if contract.Canonical != "" {
// Check self-reference
if contract.Canonical == contract.Name {
return fmt.Errorf("contract %s cannot have itself as canonical", contract.Name)
}
}
}
return nil
}

// GetAliases returns all contracts that have the given contract as their canonical.
func (c *Contracts) GetAliases(canonicalName string) []*Contract {
var aliases []*Contract
for i, contract := range *c {
if contract.Canonical == canonicalName {
aliases = append(aliases, &(*c)[i])
}
}
return aliases
}

const dependencyManagerDirectory = "imports"

const (
Expand Down
119 changes: 119 additions & 0 deletions config/contract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,122 @@ func TestContracts_AddDependencyAsContract(t *testing.T) {
assert.Equal(t, "imports/0000000000abcdef/TestContract.cdc", contract.Location)
assert.Len(t, contract.Aliases, 1)
}

func TestContract_IsAlias(t *testing.T) {
tests := []struct {
name string
contract Contract
expected bool
}{
{
name: "contract with canonical is an alias",
contract: Contract{Name: "FUSD1", Canonical: "FUSD"},
expected: true,
},
{
name: "contract without canonical is not an alias",
contract: Contract{Name: "FUSD"},
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.contract.IsAlias())
})
}
}

func TestContract_CanonicalName(t *testing.T) {
tests := []struct {
name string
contract Contract
expected string
}{
{
name: "alias returns canonical name",
contract: Contract{Name: "FUSD1", Canonical: "FUSD"},
expected: "FUSD",
},
{
name: "non-alias returns its own name",
contract: Contract{Name: "FUSD"},
expected: "FUSD",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.contract.CanonicalName())
})
}
}

func TestContracts_ValidateCanonical(t *testing.T) {
tests := []struct {
name string
contracts Contracts
wantErr bool
errMsg string
}{
{
name: "valid canonical reference",
contracts: Contracts{
{Name: "FUSD", Location: "FUSD.cdc"},
{Name: "FUSD1", Location: "FUSD.cdc", Canonical: "FUSD"},
},
wantErr: false,
},
{
name: "self-referential canonical",
contracts: Contracts{
{Name: "FUSD", Location: "FUSD.cdc", Canonical: "FUSD"},
},
wantErr: true,
errMsg: "contract FUSD cannot have itself as canonical",
},
{
name: "multiple aliases to same canonical",
contracts: Contracts{
{Name: "FUSD", Location: "FUSD.cdc"},
{Name: "FUSD1", Location: "FUSD.cdc", Canonical: "FUSD"},
{Name: "FUSD2", Location: "FUSD.cdc", Canonical: "FUSD"},
},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.contracts.ValidateCanonical()
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
assert.NoError(t, err)
}
})
}
}

func TestContracts_GetAliases(t *testing.T) {
contracts := Contracts{
{Name: "FUSD", Location: "FUSD.cdc"},
{Name: "FUSD1", Location: "FUSD.cdc", Canonical: "FUSD"},
{Name: "FUSD2", Location: "FUSD.cdc", Canonical: "FUSD"},
{Name: "FT", Location: "FT.cdc"},
{Name: "FT1", Location: "FT.cdc", Canonical: "FT"},
}

fusdAliases := contracts.GetAliases("FUSD")
assert.Len(t, fusdAliases, 2)
assert.Equal(t, "FUSD1", fusdAliases[0].Name)
assert.Equal(t, "FUSD2", fusdAliases[1].Name)

ftAliases := contracts.GetAliases("FT")
assert.Len(t, ftAliases, 1)
assert.Equal(t, "FT1", ftAliases[0].Name)

noAliases := contracts.GetAliases("NonExistent")
assert.Len(t, noAliases, 0)
}
19 changes: 11 additions & 8 deletions config/json/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ func (j jsonContracts) transformToConfig() (config.Contracts, error) {
contracts = append(contracts, contract)
} else {
contract := config.Contract{
Name: contractName,
Location: c.Advanced.Source,
Name: contractName,
Location: c.Advanced.Source,
Canonical: c.Advanced.Canonical,
}
for network, alias := range c.Advanced.Aliases {
address := flow.HexToAddress(alias)
Expand All @@ -73,8 +74,8 @@ func transformContractsToJSON(contracts config.Contracts) jsonContracts {
continue
}

// if simple case
if !c.IsAliased() {
// if simple case (no aliases and no canonical)
if !c.IsAliased() && c.Canonical == "" {
jsonContracts[c.Name] = jsonContract{
Simple: filepath.ToSlash(c.Location),
}
Expand All @@ -87,8 +88,9 @@ func transformContractsToJSON(contracts config.Contracts) jsonContracts {

jsonContracts[c.Name] = jsonContract{
Advanced: jsonContractAdvanced{
Source: filepath.ToSlash(c.Location),
Aliases: aliases,
Source: filepath.ToSlash(c.Location),
Aliases: aliases,
Canonical: c.Canonical,
},
}
}
Expand All @@ -99,8 +101,9 @@ func transformContractsToJSON(contracts config.Contracts) jsonContracts {

// jsonContractAdvanced for json parsing advanced config.
type jsonContractAdvanced struct {
Source string `json:"source"`
Aliases map[string]string `json:"aliases"`
Source string `json:"source"`
Aliases map[string]string `json:"aliases"`
Canonical string `json:"canonical,omitempty"`
}

// jsonContract structure for json parsing.
Expand Down
4 changes: 4 additions & 0 deletions flowkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ func (f *Flowkit) AddContract(
importReplacer := project.NewImportReplacer(
contracts,
state.AliasesForNetwork(f.network),
state.CanonicalContractMapping(),
)

program, err = importReplacer.Replace(program)
Expand Down Expand Up @@ -833,6 +834,7 @@ func (f *Flowkit) ExecuteScript(ctx context.Context, script Script, query Script
importReplacer := project.NewImportReplacer(
contracts,
state.AliasesForNetwork(f.network),
state.CanonicalContractMapping(),
)

if state == nil {
Expand Down Expand Up @@ -990,6 +992,7 @@ func (f *Flowkit) BuildTransaction(
importReplacer := project.NewImportReplacer(
contracts,
state.AliasesForNetwork(f.network),
state.CanonicalContractMapping(),
)

program, err = importReplacer.Replace(program)
Expand Down Expand Up @@ -1122,6 +1125,7 @@ func (f *Flowkit) ReplaceImportsInScript(
importReplacer := project.NewImportReplacer(
contracts,
state.AliasesForNetwork(f.network),
state.CanonicalContractMapping(),
)

program, err := project.NewProgram(script.Code, script.Args, script.Location)
Expand Down
45 changes: 38 additions & 7 deletions project/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@ type Account interface {

// ImportReplacer implements file import replacements functionality for the project contracts with optionally included aliases.
type ImportReplacer struct {
contracts []*Contract
aliases LocationAliases
contracts []*Contract
aliases LocationAliases
canonicalMapping map[string]string // maps alias names to their canonical contract names
}

func NewImportReplacer(contracts []*Contract, aliases LocationAliases) *ImportReplacer {
func NewImportReplacer(contracts []*Contract, aliases LocationAliases, canonicalMapping ...map[string]string) *ImportReplacer {
canonical := make(map[string]string)
// If canonical mapping is provided, use it
if len(canonicalMapping) > 0 && canonicalMapping[0] != nil {
canonical = canonicalMapping[0]
}

return &ImportReplacer{
contracts: contracts,
aliases: aliases,
contracts: contracts,
aliases: aliases,
canonicalMapping: canonical,
}
}

Expand All @@ -52,13 +60,17 @@ func (i *ImportReplacer) Replace(program *Program) (*Program, error) {
importLocation := filepath.Clean(absolutePath(program.Location(), imp))
address, isPath := contractsLocations[importLocation]
if isPath {
program.replaceImport(imp, address)
// Check if this import is an alias
canonicalName := i.getCanonicalNameForImport(imp, address)
program.replaceImport(imp, address, canonicalName)
continue
}
// check if import by identifier exists (e.g. import ["X"])
address, isIdentifier := contractsLocations[imp]
if isIdentifier {
program.replaceImport(imp, address)
// Check if this import is an alias
canonicalName := i.getCanonicalNameForImport(imp, address)
program.replaceImport(imp, address, canonicalName)
continue
}

Expand All @@ -84,6 +96,25 @@ func (i *ImportReplacer) getContractsLocations() map[string]string {
return locationAddress
}

// getCanonicalNameForImport determines the canonical contract name for an import.
// Returns the canonical name if the import is an alias, otherwise returns the import name.
func (i *ImportReplacer) getCanonicalNameForImport(importName string, address string) string {
// Extract just the contract name from the import path if it's a path
contractName := importName
if filepath.Ext(importName) == ".cdc" {
contractName = filepath.Base(importName)
contractName = contractName[:len(contractName)-4] // Remove .cdc extension
}

// Check if this is an alias by looking up in canonical mapping
if canonicalName, isAlias := i.canonicalMapping[contractName]; isAlias {
return canonicalName
}

// Not an alias, return the original contract name
return contractName
}

func absolutePath(basePath, relativePath string) string {
return filepath.Join(filepath.Dir(basePath), relativePath)
}
Loading
Loading