diff --git a/agent/app/model/firewall.go b/agent/app/model/firewall.go index 7218d4238521..87fbb7efcd14 100644 --- a/agent/app/model/firewall.go +++ b/agent/app/model/firewall.go @@ -3,10 +3,10 @@ package model type Firewall struct { BaseModel - Type string `gorm:"not null" json:"type"` - FirewallType string `gorm:"not null" json:"firewallType"` - Port string `gorm:"not null" json:"port"` // Deprecated - Address string `gorm:"not null" json:"address"` // Deprecated + Type string `json:"type"` + FirewallType string `json:"firewallType"` + Port string `json:"port"` // Deprecated + Address string `json:"address"` // Deprecated Chain string `json:"chain"` Protocol string `json:"protocol"` diff --git a/agent/app/service/iptables.go b/agent/app/service/iptables.go index 28ea50d71ef4..f86889cda3ac 100644 --- a/agent/app/service/iptables.go +++ b/agent/app/service/iptables.go @@ -97,21 +97,23 @@ func (s *IptablesService) OperateRule(req dto.IptablesRuleOp) error { return fmt.Errorf("failed to save rule to database: %w", err) } return nil - } - if err := iptables.DeleteFilterRule(req.Chain, policy); err != nil { - return fmt.Errorf("failed to remove iptables rule: %w", err) - } - name := iptables.InputFileName - if req.Chain == iptables.Chain1PanelOutput { - name = iptables.OutputFileName - } - if err := iptables.SaveRulesToFile(iptables.FilterTab, req.Chain, name); err != nil { - global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err) - } - if req.ID != 0 { - if err := hostRepo.DeleteFirewallRecordByID(req.ID); err != nil { - return fmt.Errorf("failed to delete rule from database: %w", err) + } else if req.Operation == "remove" { + if err := iptables.DeleteFilterRule(req.Chain, policy); err != nil { + return fmt.Errorf("failed to remove iptables rule: %w", err) + } + name := iptables.InputFileName + if req.Chain == iptables.Chain1PanelOutput { + name = iptables.OutputFileName + } + if err := iptables.SaveRulesToFile(iptables.FilterTab, req.Chain, name); err != nil { + global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err) + } + if req.ID != 0 { + if err := hostRepo.DeleteFirewallRecordByID(req.ID); err != nil { + return fmt.Errorf("failed to delete rule from database: %w", err) + } } + return nil } return nil } @@ -126,14 +128,14 @@ func (s *IptablesService) BatchOperate(req dto.IptablesBatchOperate) error { } func (s *IptablesService) Operate(req dto.IptablesOp) error { - targetChain := "INPUT" - if req.Name == "1PANEL_OUTPUT" { - targetChain = "OUTPUT" + targetChain := iptables.ChainInput + if req.Name == iptables.Chain1PanelOutput { + targetChain = iptables.ChainOutput } switch req.Operate { case "init-base": - if err := cmd.RunDefaultBashC("modprobe ip_tables"); err != nil { - return fmt.Errorf("failed to load ip_tables module: %v", err) + if ok := cmd.Which("iptables"); !ok { + return fmt.Errorf("failed to find iptables") } if err := iptables.AddChain(iptables.FilterTab, iptables.Chain1PanelBasicBefore); err != nil { return err @@ -173,6 +175,7 @@ func (s *IptablesService) Operate(req dto.IptablesOp) error { if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelInput, number); err != nil { return err } + return nil case "bind-base": if err := initPreRules(); err != nil { return err @@ -186,6 +189,7 @@ func (s *IptablesService) Operate(req dto.IptablesOp) error { if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicAfter, 3); err != nil { return err } + return nil case "unbind-base": if err := iptables.UnbindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicAfter); err != nil { return err @@ -196,14 +200,17 @@ func (s *IptablesService) Operate(req dto.IptablesOp) error { if err := iptables.UnbindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasic); err != nil { return err } + return nil case "bind": if err := iptables.BindChain(iptables.FilterTab, targetChain, req.Name, loadBindNumber()); err != nil { return err } + return nil case "unbind": if err := iptables.UnbindChain(iptables.FilterTab, targetChain, req.Name); err != nil { return err } + return nil } return nil } diff --git a/agent/utils/firewall/client/iptables/filter.go b/agent/utils/firewall/client/iptables/filter.go index f476350984d3..21dda8fd7fda 100644 --- a/agent/utils/firewall/client/iptables/filter.go +++ b/agent/utils/firewall/client/iptables/filter.go @@ -21,6 +21,9 @@ type FilterRules struct { } func AddFilterRule(chain string, policy FilterRules) error { + if err := validateRuleSafety(policy, chain); err != nil { + return err + } iptablesArg := fmt.Sprintf("-A %s", chain) if policy.Protocol != "" { iptablesArg += fmt.Sprintf(" -p %s", policy.Protocol) @@ -130,3 +133,25 @@ func loadIP(ipStr string) string { } return ipStr } + +func validateRuleSafety(rule FilterRules, chain string) error { + if strings.ToUpper(rule.Strategy) != "DROP" { + return nil + } + + // 入方向检查是否存在无条件 DROP + if chain == ChainInput || chain == Chain1PanelInput || chain == Chain1PanelBasic { + if rule.SrcIP == "0.0.0.0/0" && rule.SrcPort == 0 && rule.DstPort == 0 { + return fmt.Errorf("unsafe DROP is not allowed") + } + } + + // 出方向检查是否存在无条件 DROP + if chain == ChainOutput || chain == Chain1PanelOutput || chain == Chain1PanelBasicAfter { + if rule.DstIP == "0.0.0.0/0" && rule.DstPort == 0 && rule.SrcPort == 0 { + return fmt.Errorf("unsafe DROP is not allowed") + } + } + + return nil +}