Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions agent/app/model/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package model
type Firewall struct {
BaseModel

Type string `gorm:"not null" json:"type"`
FirewallType string `gorm:"not null" json:"firewallType"`
Port string `gorm:"not null" json:"port"` // Deprecated
Address string `gorm:"not null" json:"address"` // Deprecated
Type string `json:"type"`
FirewallType string `json:"firewallType"`
Port string `json:"port"` // Deprecated
Address string `json:"address"` // Deprecated

Chain string `json:"chain"`
Protocol string `json:"protocol"`
Expand Down
45 changes: 26 additions & 19 deletions agent/app/service/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,23 @@ func (s *IptablesService) OperateRule(req dto.IptablesRuleOp) error {
return fmt.Errorf("failed to save rule to database: %w", err)
}
return nil
}
if err := iptables.DeleteFilterRule(req.Chain, policy); err != nil {
return fmt.Errorf("failed to remove iptables rule: %w", err)
}
name := iptables.InputFileName
if req.Chain == iptables.Chain1PanelOutput {
name = iptables.OutputFileName
}
if err := iptables.SaveRulesToFile(iptables.FilterTab, req.Chain, name); err != nil {
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err)
}
if req.ID != 0 {
if err := hostRepo.DeleteFirewallRecordByID(req.ID); err != nil {
return fmt.Errorf("failed to delete rule from database: %w", err)
} else if req.Operation == "remove" {
if err := iptables.DeleteFilterRule(req.Chain, policy); err != nil {
return fmt.Errorf("failed to remove iptables rule: %w", err)
}
name := iptables.InputFileName
if req.Chain == iptables.Chain1PanelOutput {
name = iptables.OutputFileName
}
if err := iptables.SaveRulesToFile(iptables.FilterTab, req.Chain, name); err != nil {
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err)
}
if req.ID != 0 {
if err := hostRepo.DeleteFirewallRecordByID(req.ID); err != nil {
return fmt.Errorf("failed to delete rule from database: %w", err)
}
}
return nil
}
return nil
}
Expand All @@ -126,14 +128,14 @@ func (s *IptablesService) BatchOperate(req dto.IptablesBatchOperate) error {
}

func (s *IptablesService) Operate(req dto.IptablesOp) error {
targetChain := "INPUT"
if req.Name == "1PANEL_OUTPUT" {
targetChain = "OUTPUT"
targetChain := iptables.ChainInput
if req.Name == iptables.Chain1PanelOutput {
targetChain = iptables.ChainOutput
}
switch req.Operate {
case "init-base":
if err := cmd.RunDefaultBashC("modprobe ip_tables"); err != nil {
return fmt.Errorf("failed to load ip_tables module: %v", err)
if ok := cmd.Which("iptables"); !ok {
return fmt.Errorf("failed to find iptables")
}
if err := iptables.AddChain(iptables.FilterTab, iptables.Chain1PanelBasicBefore); err != nil {
return err
Expand Down Expand Up @@ -173,6 +175,7 @@ func (s *IptablesService) Operate(req dto.IptablesOp) error {
if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelInput, number); err != nil {
return err
}
return nil
case "bind-base":
if err := initPreRules(); err != nil {
return err
Expand All @@ -186,6 +189,7 @@ func (s *IptablesService) Operate(req dto.IptablesOp) error {
if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicAfter, 3); err != nil {
return err
}
return nil
case "unbind-base":
if err := iptables.UnbindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicAfter); err != nil {
return err
Expand All @@ -196,14 +200,17 @@ func (s *IptablesService) Operate(req dto.IptablesOp) error {
if err := iptables.UnbindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasic); err != nil {
return err
}
return nil
case "bind":
if err := iptables.BindChain(iptables.FilterTab, targetChain, req.Name, loadBindNumber()); err != nil {
return err
}
return nil
case "unbind":
if err := iptables.UnbindChain(iptables.FilterTab, targetChain, req.Name); err != nil {
return err
}
return nil
}
return nil
}
Expand Down
25 changes: 25 additions & 0 deletions agent/utils/firewall/client/iptables/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ type FilterRules struct {
}

func AddFilterRule(chain string, policy FilterRules) error {
if err := validateRuleSafety(policy, chain); err != nil {
return err
}
iptablesArg := fmt.Sprintf("-A %s", chain)
if policy.Protocol != "" {
iptablesArg += fmt.Sprintf(" -p %s", policy.Protocol)
Expand Down Expand Up @@ -130,3 +133,25 @@ func loadIP(ipStr string) string {
}
return ipStr
}

func validateRuleSafety(rule FilterRules, chain string) error {
if strings.ToUpper(rule.Strategy) != "DROP" {
return nil
}

// 入方向检查是否存在无条件 DROP
if chain == ChainInput || chain == Chain1PanelInput || chain == Chain1PanelBasic {
if rule.SrcIP == "0.0.0.0/0" && rule.SrcPort == 0 && rule.DstPort == 0 {
return fmt.Errorf("unsafe DROP is not allowed")
}
}

// 出方向检查是否存在无条件 DROP
if chain == ChainOutput || chain == Chain1PanelOutput || chain == Chain1PanelBasicAfter {
if rule.DstIP == "0.0.0.0/0" && rule.DstPort == 0 && rule.SrcPort == 0 {
return fmt.Errorf("unsafe DROP is not allowed")
}
}

return nil
}
Loading