Use linter code standards

This commit is contained in:
Fredrik Berntsson
2025-04-20 00:11:18 +02:00
committed by lovelaze
parent 9df4e352e4
commit dcd44b726a
31 changed files with 441 additions and 543 deletions
+128 -124
View File
@@ -30,116 +30,116 @@ formatters:
linters:
enable:
- asasalint # checks for pass []any as any in variadic func(...any)
- asciicheck # checks that your code does not contain non-ASCII identifiers
- bidichk # checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- canonicalheader # checks whether net/http.Header uses canonical header
- copyloopvar # detects places where loop variables are copied (Go 1.22+)
- cyclop # checks function and package cyclomatic complexity
- depguard # checks if package imports are in a list of acceptable packages
- dupl # tool for code clone detection
- durationcheck # checks for two durations multiplied together
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error
- errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13
- exhaustive # checks exhaustiveness of enum switch statements
- exptostd # detects functions from golang.org/x/exp/ that can be replaced by std functions
- fatcontext # detects nested contexts in loops
- forbidigo # forbids identifiers
- funcorder # checks the order of functions, methods, and constructors
- funlen # tool for detection of long functions
- gocheckcompilerdirectives # validates go compiler directive comments (//go:)
- gochecksumtype # checks exhaustiveness on Go "sum types"
- gocognit # computes and checks the cognitive complexity of functions
- goconst # finds repeated strings that could be replaced by a constant
- gocritic # provides diagnostics that check for bugs, performance and style issues
- gocyclo # computes and checks the cyclomatic complexity of functions
- godot # checks if comments end in a period
- gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod
- goprintffuncname # checks that printf-like functions are named with f at the end
- gosec # inspects source code for security problems
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- iface # checks the incorrect use of interfaces, helping developers avoid interface pollution
- ineffassign # detects when assignments to existing variables are not used
- intrange # finds places where for loops could make use of an integer range
- loggercheck # checks key value pairs for common logger libraries (kitlog,klog,logr,zap)
- makezero # finds slice declarations with non-zero initial length
- mirror # reports wrong mirror patterns of bytes/strings usage
- mnd # detects magic numbers
- musttag # enforces field tags in (un)marshaled structs
- nakedret # finds naked returns in functions greater than a specified function length
- nestif # reports deeply nested if statements
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnesserr # reports that it checks for err != nil, but it returns a different nil value error (powered by nilness and nilerr)
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- noctx # finds sending http request without context.Context
- nolintlint # reports ill-formed or insufficient nolint directives
- nonamedreturns # reports all named returns
- nosprintfhostport # checks for misuse of Sprintf to construct a host with port in a URL
- perfsprint # checks that fmt.Sprintf can be replaced with a faster alternative
- predeclared # finds code that shadows one of Go's predeclared identifiers
- promlinter # checks Prometheus metrics naming via promlint
- protogetter # reports direct reads from proto message fields when getters should be used
- reassign # checks that package variables are not reassigned
- recvcheck # checks for receiver type consistency
- revive # fast, configurable, extensible, flexible, and beautiful linter for Go, drop-in replacement of golint
- rowserrcheck # checks whether Err of rows is checked successfully
- sloglint # ensure consistent code style when using log/slog
- spancheck # checks for mistakes with OpenTelemetry/Census spans
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- testableexamples # checks if examples are testable (have an expected output)
- testifylint # checks usage of github.com/stretchr/testify
- testpackage # makes you use a separate _test package
- tparallel # detects inappropriate usage of t.Parallel() method in your Go test codes
- unconvert # removes unnecessary type conversions
- unparam # reports unused function parameters
- unused # checks for unused constants, variables, functions and types
- usestdlibvars # detects the possibility to use variables/constants from the Go standard library
- usetesting # reports uses of functions with replacement inside the testing package
- wastedassign # finds wasted assignment statements
- whitespace # detects leading and trailing whitespace
- zerologlint # detects the wrong usage of zerolog that a user forgets to dispatch zerolog.Event
- asasalint
- asciicheck
- bidichk
- bodyclose
- canonicalheader
- copyloopvar
- cyclop
- depguard
- dupl
- durationcheck
- errcheck
- errname
- errorlint
- exhaustive
- exptostd
- fatcontext
- forbidigo
- funcorder
- funlen
- gocheckcompilerdirectives
- gochecksumtype
- gocognit
- goconst
- gocritic
- gocyclo
- godot
- gomoddirectives
- goprintffuncname
- gosec
- govet
- iface
- ineffassign
- intrange
- loggercheck
- makezero
- mirror
- mnd
- musttag
- nakedret
- nestif
- nilerr
- nilnesserr
- nilnil
- noctx
- nolintlint
- nonamedreturns
- nosprintfhostport
- perfsprint
- predeclared
- promlinter
- protogetter
- reassign
- recvcheck
- revive
- rowserrcheck
- sloglint
- spancheck
- sqlclosecheck
- staticcheck
- testableexamples
- testifylint
- tparallel
- unconvert
- unparam
- unused
- usestdlibvars
- usetesting
- wastedassign
- whitespace
- zerologlint
## should be evaluated
#- gochecknoglobals
#- gochecknoinits
#- testpackage
## you may want to enable
#- decorder # checks declaration order and count of types, constants, variables and functions
#- exhaustruct # [highly recommend to enable] checks if all structure fields are initialized
#- ginkgolinter # [if you use ginkgo/gomega] enforces standards of using ginkgo and gomega
#- godox # detects usage of FIXME, TODO and other keywords inside comments
#- goheader # checks is file header matches to pattern
#- inamedparam # [great idea, but too strict, need to ignore a lot of cases by default] reports interfaces with unnamed method parameters
#- interfacebloat # checks the number of methods inside an interface
#- ireturn # accept interfaces, return concrete types
#- prealloc # [premature optimization, but can be used in some cases] finds slice declarations that could potentially be preallocated
#- tagalign # checks that struct tags are well aligned
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
#- wrapcheck # checks that errors returned from external packages are wrapped
#- gochecknoglobals # checks that no global variables exist
#- gochecknoinits # checks that no init functions are present in Go code
#- decorder
#- exhaustruct
#- ginkgolinter
#- godox
#- goheader
#- inamedparam
#- interfacebloat
#- ireturn
#- prealloc
#- tagalign
#- varnamelen
#- wrapcheck
## disabled
#- containedctx # detects struct contained context.Context field
#- contextcheck # [too many false positives] checks the function whether use a non-inherited context
#- dogsled # checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
#- dupword # [useless without config] checks for duplicate words in the source code
#- err113 # [too strict] checks the errors handling expressions
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
#- gomodguard # [use more powerful depguard] allow and block lists linter for direct Go module dependencies
#- gosmopolitan # reports certain i18n/l10n anti-patterns in your Go codebase
#- grouper # analyzes expression groups
#- importas # enforces consistent import aliases
#- lll # [replaced by golines] reports long lines
#- maintidx # measures the maintainability index of each function
#- misspell # [useless] finds commonly misspelled English words in comments
#- nlreturn # [too strict and mostly code is not more readable] checks for a new line before return and branch statements to increase code clarity
#- paralleltest # [too many false positives] detects missing usage of t.Parallel() method in your Go test
#- tagliatelle # checks the struct tags
#- thelper # detects golang test helpers without t.Helper() call and checks the consistency of test helpers
#- wsl # [too strict and mostly code is not more readable] whitespace linter forces you to use empty lines
#- containedctx
#- contextcheck
#- dogsled
#- dupword
#- err113
#- errchkjson
#- forcetypeassert
#- gomodguard
#- gosmopolitan
#- grouper
#- importas
#- lll
#- maintidx
#- misspell
#- nlreturn
#- paralleltest
#- tagliatelle
#- thelper
#- wsl
# All settings can be found here https://github.com/golangci/golangci-lint/blob/HEAD/.golangci.reference.yml
settings:
@@ -229,10 +229,8 @@ linters:
govet:
enable-all: true
disable:
- fieldalignment # too strict
settings:
shadow:
strict: true
- fieldalignment
- shadow
inamedparam:
skip-single-param: true
@@ -272,21 +270,7 @@ linters:
- github.com/jmoiron/sqlx
sloglint:
# Enforce not using global loggers.
# Values:
# - "": disabled
# - "all": report all global loggers
# - "default": report only the default slog logger
# https://github.com/go-simpler/sloglint?tab=readme-ov-file#no-global
# Default: ""
no-global: all
# Enforce using methods that accept a context.
# Values:
# - "": disabled
# - "all": report all contextless calls
# - "scope": report only if a context exists in the scope of the outermost function
# https://github.com/go-simpler/sloglint?tab=readme-ov-file#context-only
# Default: ""
context: scope
staticcheck:
@@ -309,6 +293,26 @@ linters:
disabled: true
- name: unused-parameter
disabled: true
- name: line-length-limit
disabled: true
- name: function-length
disabled: true
- name: max-public-structs
disabled: true
- name: import-shadowing
disabled: true
- name: cognitive-complexity
disabled: true
- name: nested-structs
disabled: true
- name: exported
disabled: true
- name: flag-parameter
disabled: true
gosec:
excludes:
- G402
exclusions:
warn-unused: false
+1 -1
View File
@@ -4,7 +4,7 @@ lint:
fmt:
golangci-lint fmt
mock:
mocks:
mockery --config .mockery.yaml
unit-test:
+2 -2
View File
@@ -11,8 +11,8 @@ var rootCmd = &cobra.Command{
Version: version.Version,
}
func Execute() {
rootCmd.Execute()
func Execute() error {
return rootCmd.Execute()
}
func init() {
+3 -1
View File
@@ -9,6 +9,8 @@ import (
"time"
)
const startupTimeout = 30 * time.Second
type PiHoleContainer struct {
Container tc.Container
password string
@@ -45,7 +47,7 @@ func (c *PiHoleContainer) EnvString(ssl bool) string {
func RunPiHole(password string) *PiHoleContainer {
logStrategy := wait.ForLog("listening on")
portStrategy := wait.ForListeningPort("80").WithStartupTimeout(30 * time.Second)
portStrategy := wait.ForListeningPort("80").WithStartupTimeout(startupTimeout)
containerReq := tc.GenericContainerRequest{
ContainerRequest: tc.ContainerRequest{
+10 -11
View File
@@ -2,7 +2,6 @@ package e2e
import (
"github.com/lovelaze/nebula-sync/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"testing"
)
@@ -29,9 +28,9 @@ func (suite *testSuite) Test_FullSync() {
suite.T().Setenv("RUN_GRAVITY", "true")
s, err := service.Init()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
err = s.Run()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
}
func (suite *testSuite) Test_FullSync_SSL() {
@@ -41,9 +40,9 @@ func (suite *testSuite) Test_FullSync_SSL() {
suite.T().Setenv("CLIENT_SKIP_TLS_VERIFICATION", "true")
s, err := service.Init()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
err = s.Run()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
}
func (suite *testSuite) Test_SelectiveSync() {
@@ -55,9 +54,9 @@ func (suite *testSuite) Test_SelectiveSync() {
setAllManualGravity(suite)
s, err := service.Init()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
err = s.Run()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
}
func (suite *testSuite) Test_SelectiveSync_Include() {
@@ -76,9 +75,9 @@ func (suite *testSuite) Test_SelectiveSync_Include() {
suite.T().Setenv("SYNC_CONFIG_DEBUG_INCLUDE", "database,networking")
s, err := service.Init()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
err = s.Run()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
}
func (suite *testSuite) Test_SelectiveSync_Exclude() {
@@ -97,9 +96,9 @@ func (suite *testSuite) Test_SelectiveSync_Exclude() {
suite.T().Setenv("SYNC_CONFIG_DEBUG_EXCLUDE", "database,networking")
s, err := service.Init()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
err = s.Run()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
}
func setAllManualConfig(suite *testSuite) {
+3 -3
View File
@@ -26,11 +26,11 @@ func (c *Config) loadClient() error {
return nil
}
func (settings *Client) NewHttpClient() *http.Client {
func (c *Client) NewHTTPClient() *http.Client {
return &http.Client{
Timeout: time.Duration(settings.Timeout) * time.Second,
Timeout: time.Duration(c.Timeout) * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: settings.SkipTLSVerification},
TLSClientConfig: &tls.Config{InsecureSkipVerify: c.SkipTLSVerification},
},
}
}
+1 -1
View File
@@ -16,7 +16,7 @@ func TestConfig_LoadClient(t *testing.T) {
err := conf.loadClient()
require.NoError(t, err)
assert.Equal(t, true, conf.Client.SkipTLSVerification)
assert.True(t, conf.Client.SkipTLSVerification)
assert.Equal(t, int64(45), conf.Client.Timeout)
assert.Equal(t, int64(5), conf.Client.RetryDelay)
}
+8 -14
View File
@@ -99,11 +99,8 @@ func (raw *RawConfigSettings) Validate() error {
if err := exclusive("misc", raw.MiscInclude, raw.MiscExclude); err != nil {
return err
}
if err := exclusive("debug", raw.DebugInclude, raw.DebugExclude); err != nil {
return err
}
return nil
return exclusive("debug", raw.DebugInclude, raw.DebugExclude)
}
func (raw *RawConfigSettings) Parse() (*ConfigSettings, error) {
@@ -144,11 +141,12 @@ func newConfigFilter(filterType filter.Type, keys []string) *ConfigFilter {
func NewConfigSetting(enabled bool, included, excluded []string) *ConfigSetting {
var configFilter *ConfigFilter
if included != nil {
switch {
case included != nil:
configFilter = newConfigFilter(filter.Include, included)
} else if excluded != nil {
case excluded != nil:
configFilter = newConfigFilter(filter.Exclude, excluded)
} else {
default:
configFilter = nil
}
@@ -171,11 +169,7 @@ func (c *Config) Load() error {
return err
}
if err := c.loadWebhookSettings(); err != nil {
return err
}
return nil
return c.loadWebhookSettings()
}
func (c *Config) loadSync() error {
@@ -192,7 +186,7 @@ func (c *Config) loadSync() error {
return nil
}
func (sync *Sync) loadConfigSettings() error {
func (s *Sync) loadConfigSettings() error {
raw := RawConfigSettings{}
if err := envconfig.Process("", &raw); err != nil {
@@ -204,7 +198,7 @@ func (sync *Sync) loadConfigSettings() error {
return err
}
sync.ConfigSettings = configSettings
s.ConfigSettings = configSettings
return nil
}
+39 -39
View File
@@ -18,12 +18,12 @@ func TestConfig_Load(t *testing.T) {
err := conf.Load()
require.NoError(t, err)
assert.Equal(t, "http://localhost:1337", conf.Primary.Url.String())
assert.Equal(t, "http://localhost:1337", conf.Primary.URL.String())
assert.Equal(t, "asdf", conf.Primary.Password)
assert.Len(t, conf.Replicas, 1)
assert.Equal(t, "http://localhost:1338", conf.Replicas[0].Url.String())
assert.Equal(t, "http://localhost:1338", conf.Replicas[0].URL.String())
assert.Equal(t, "qwerty", conf.Replicas[0].Password)
assert.Equal(t, false, conf.Sync.FullSync)
assert.False(t, conf.Sync.FullSync)
assert.Equal(t, "POST", conf.Sync.WebhookSettings.Success.Method)
assert.Equal(t, "POST", conf.Sync.WebhookSettings.Failure.Method)
}
@@ -56,9 +56,9 @@ func TestConfig_loadSync(t *testing.T) {
err := conf.loadSync()
require.NoError(t, err)
assert.Equal(t, true, conf.Sync.FullSync)
assert.True(t, conf.Sync.FullSync)
assert.Equal(t, "* * * * *", *conf.Sync.Cron)
assert.Equal(t, true, conf.Sync.RunGravity)
assert.True(t, conf.Sync.RunGravity)
assert.NotNil(t, conf.Sync.ConfigSettings)
assert.NotNil(t, conf.Sync.GravitySettings)
@@ -120,24 +120,24 @@ func TestRawConfig_Parse_Include(t *testing.T) {
t.Setenv("SYNC_CONFIG_DEBUG_INCLUDE", "key13,key14")
sync := Sync{}
assert.NoError(t, sync.loadConfigSettings())
require.NoError(t, sync.loadConfigSettings())
settings := sync.ConfigSettings
assert.Equal(t, settings.DNS.Filter.Type, filter.Include)
assert.Equal(t, settings.DNS.Filter.Keys, []string{"key1", "key2"})
assert.Equal(t, settings.DHCP.Filter.Type, filter.Include)
assert.Equal(t, settings.DHCP.Filter.Keys, []string{"key3", "key4"})
assert.Equal(t, settings.NTP.Filter.Type, filter.Include)
assert.Equal(t, settings.NTP.Filter.Keys, []string{"key5", "key6"})
assert.Equal(t, settings.Resolver.Filter.Type, filter.Include)
assert.Equal(t, settings.Resolver.Filter.Keys, []string{"key7", "key8"})
assert.Equal(t, settings.Database.Filter.Type, filter.Include)
assert.Equal(t, settings.Database.Filter.Keys, []string{"key9", "key10"})
assert.Equal(t, settings.Misc.Filter.Type, filter.Include)
assert.Equal(t, settings.Misc.Filter.Keys, []string{"key11", "key12"})
assert.Equal(t, settings.Debug.Filter.Type, filter.Include)
assert.Equal(t, settings.Debug.Filter.Keys, []string{"key13", "key14"})
assert.Equal(t, filter.Include, settings.DNS.Filter.Type)
assert.Equal(t, []string{"key1", "key2"}, settings.DNS.Filter.Keys)
assert.Equal(t, filter.Include, settings.DHCP.Filter.Type)
assert.Equal(t, []string{"key3", "key4"}, settings.DHCP.Filter.Keys)
assert.Equal(t, filter.Include, settings.NTP.Filter.Type)
assert.Equal(t, []string{"key5", "key6"}, settings.NTP.Filter.Keys)
assert.Equal(t, filter.Include, settings.Resolver.Filter.Type)
assert.Equal(t, []string{"key7", "key8"}, settings.Resolver.Filter.Keys)
assert.Equal(t, filter.Include, settings.Database.Filter.Type)
assert.Equal(t, []string{"key9", "key10"}, settings.Database.Filter.Keys)
assert.Equal(t, filter.Include, settings.Misc.Filter.Type)
assert.Equal(t, []string{"key11", "key12"}, settings.Misc.Filter.Keys)
assert.Equal(t, filter.Include, settings.Debug.Filter.Type)
assert.Equal(t, []string{"key13", "key14"}, settings.Debug.Filter.Keys)
}
func TestRawConfig_Parse_Exclude(t *testing.T) {
@@ -150,24 +150,24 @@ func TestRawConfig_Parse_Exclude(t *testing.T) {
t.Setenv("SYNC_CONFIG_DEBUG_EXCLUDE", "key13,key14")
sync := Sync{}
assert.NoError(t, sync.loadConfigSettings())
require.NoError(t, sync.loadConfigSettings())
settings := sync.ConfigSettings
assert.Equal(t, settings.DNS.Filter.Type, filter.Exclude)
assert.Equal(t, settings.DNS.Filter.Keys, []string{"key1", "key2"})
assert.Equal(t, settings.DHCP.Filter.Type, filter.Exclude)
assert.Equal(t, settings.DHCP.Filter.Keys, []string{"key3", "key4"})
assert.Equal(t, settings.NTP.Filter.Type, filter.Exclude)
assert.Equal(t, settings.NTP.Filter.Keys, []string{"key5", "key6"})
assert.Equal(t, settings.Resolver.Filter.Type, filter.Exclude)
assert.Equal(t, settings.Resolver.Filter.Keys, []string{"key7", "key8"})
assert.Equal(t, settings.Database.Filter.Type, filter.Exclude)
assert.Equal(t, settings.Database.Filter.Keys, []string{"key9", "key10"})
assert.Equal(t, settings.Misc.Filter.Type, filter.Exclude)
assert.Equal(t, settings.Misc.Filter.Keys, []string{"key11", "key12"})
assert.Equal(t, settings.Debug.Filter.Type, filter.Exclude)
assert.Equal(t, settings.Debug.Filter.Keys, []string{"key13", "key14"})
assert.Equal(t, filter.Exclude, settings.DNS.Filter.Type)
assert.Equal(t, []string{"key1", "key2"}, settings.DNS.Filter.Keys)
assert.Equal(t, filter.Exclude, settings.DHCP.Filter.Type)
assert.Equal(t, []string{"key3", "key4"}, settings.DHCP.Filter.Keys)
assert.Equal(t, filter.Exclude, settings.NTP.Filter.Type)
assert.Equal(t, []string{"key5", "key6"}, settings.NTP.Filter.Keys)
assert.Equal(t, filter.Exclude, settings.Resolver.Filter.Type)
assert.Equal(t, []string{"key7", "key8"}, settings.Resolver.Filter.Keys)
assert.Equal(t, filter.Exclude, settings.Database.Filter.Type)
assert.Equal(t, []string{"key9", "key10"}, settings.Database.Filter.Keys)
assert.Equal(t, filter.Exclude, settings.Misc.Filter.Type)
assert.Equal(t, []string{"key11", "key12"}, settings.Misc.Filter.Keys)
assert.Equal(t, filter.Exclude, settings.Debug.Filter.Type)
assert.Equal(t, []string{"key13", "key14"}, settings.Debug.Filter.Keys)
}
func TestConfig_NewConfigSetting(t *testing.T) {
@@ -182,12 +182,12 @@ func TestConfig_NewConfigSetting(t *testing.T) {
include := NewConfigSetting(true, []string{"key1", "key2"}, nil)
assert.True(t, include.Enabled)
assert.NotNil(t, include.Filter)
assert.Equal(t, include.Filter.Type, filter.Include)
assert.Equal(t, include.Filter.Keys, []string{"key1", "key2"})
assert.Equal(t, filter.Include, include.Filter.Type)
assert.Equal(t, []string{"key1", "key2"}, include.Filter.Keys)
exclude := NewConfigSetting(true, nil, []string{"key1", "key2"})
assert.True(t, exclude.Enabled)
assert.NotNil(t, exclude.Filter)
assert.Equal(t, exclude.Filter.Type, filter.Exclude)
assert.Equal(t, exclude.Filter.Keys, []string{"key1", "key2"})
assert.Equal(t, filter.Exclude, exclude.Filter.Type)
assert.Equal(t, []string{"key1", "key2"}, exclude.Filter.Keys)
}
+22 -20
View File
@@ -25,32 +25,34 @@ func (c *Config) loadTargets() error {
func loadPrimary() (*model.PiHole, error) {
env := "PRIMARY"
if value := os.Getenv(fmt.Sprintf("%s_FILE", env)); len(value) > 0 {
if bytes, err := os.ReadFile(value); err != nil {
if fileValue := os.Getenv(fmt.Sprintf("%s_FILE", env)); len(fileValue) > 0 {
bytes, err := os.ReadFile(fileValue)
if err != nil {
return nil, err
} else {
return parse(strings.TrimSpace(string(bytes)))
}
} else if value := os.Getenv(env); len(value) > 0 {
return parse(value)
} else {
return nil, fmt.Errorf("missing required env: %s/%s_FILE", env, env)
return parse(strings.TrimSpace(string(bytes)))
} else if envValue := os.Getenv(env); len(envValue) > 0 {
return parse(envValue)
}
return nil, fmt.Errorf("missing required env: %s/%s_FILE", env, env)
}
func loadReplicas() ([]model.PiHole, error) {
env := "REPLICAS"
if value := os.Getenv(fmt.Sprintf("%s_FILE", env)); len(value) > 0 {
if bytes, err := os.ReadFile(value); err != nil {
if fileValue := os.Getenv(fmt.Sprintf("%s_FILE", env)); len(fileValue) > 0 {
bytes, err := os.ReadFile(fileValue)
if err != nil {
return nil, err
} else {
return parseMultiple(strings.Split(strings.TrimSpace(string(bytes)), ","))
}
} else if value := os.Getenv(env); len(value) > 0 {
return parseMultiple(strings.Split(value, ","))
} else {
return nil, fmt.Errorf("missing required env: %s/%s_FILE", env, env)
return parseMultiple(strings.Split(strings.TrimSpace(string(bytes)), ","))
} else if envValue := os.Getenv(env); len(envValue) > 0 {
return parseMultiple(strings.Split(envValue, ","))
}
return nil, fmt.Errorf("missing required env: %s/%s_FILE", env, env)
}
func parse(value string) (*model.PiHole, error) {
@@ -64,12 +66,12 @@ func parse(value string) (*model.PiHole, error) {
func parseMultiple(values []string) ([]model.PiHole, error) {
replicas := []model.PiHole{}
for _, value := range values {
if ph, err := parse(value); err != nil {
ph, err := parse(value)
if err != nil {
return nil, err
} else {
replicas = append(replicas, *ph)
}
replicas = append(replicas, *ph)
}
return replicas, nil
}
+7 -7
View File
@@ -18,12 +18,12 @@ func TestConfig_Load_Target(t *testing.T) {
err := conf.loadTargets()
require.NoError(t, err)
assert.Equal(t, "http://localhost:1337", conf.Primary.Url.String())
assert.Equal(t, "http://localhost:1337", conf.Primary.URL.String())
assert.Equal(t, "asdf", conf.Primary.Password)
assert.Len(t, conf.Replicas, 2)
assert.Equal(t, "http://localhost:1338", conf.Replicas[0].Url.String())
assert.Equal(t, "http://localhost:1338", conf.Replicas[0].URL.String())
assert.Equal(t, "qwerty", conf.Replicas[0].Password)
assert.Equal(t, "http://localhost:1339", conf.Replicas[1].Url.String())
assert.Equal(t, "http://localhost:1339", conf.Replicas[1].URL.String())
assert.Equal(t, "foobar", conf.Replicas[1].Password)
}
@@ -38,12 +38,12 @@ func TestConfig_Load_TargetFiles(t *testing.T) {
err := conf.loadTargets()
require.NoError(t, err)
assert.Equal(t, "https://ph1.example.com", conf.Primary.Url.String())
assert.Equal(t, "https://ph1.example.com", conf.Primary.URL.String())
assert.Equal(t, "password1", conf.Primary.Password)
assert.Len(t, conf.Replicas, 2)
assert.Equal(t, "https://ph2.example.com", conf.Replicas[0].Url.String())
assert.Equal(t, "https://ph2.example.com", conf.Replicas[0].URL.String())
assert.Equal(t, "password2", conf.Replicas[0].Password)
assert.Equal(t, "https://ph3.example.com", conf.Replicas[1].Url.String())
assert.Equal(t, "https://ph3.example.com", conf.Replicas[1].URL.String())
assert.Equal(t, "password3", conf.Replicas[1].Password)
}
@@ -66,7 +66,7 @@ func TestConfig_Load_NoReplicas(t *testing.T) {
require.Empty(t, os.Getenv("PRIMARY"))
require.Empty(t, os.Getenv("REPLICAS"))
require.Empty(t, os.Getenv("REPLICAS_FILE"))
err := conf.loadTargets()
assert.Error(t, err)
}
+1 -1
View File
@@ -20,7 +20,7 @@ type WebhookRequest struct {
Body string `envconfig:"BODY"`
Headers map[string]string `envconfig:"HEADERS"`
Method string `default:"POST" envconfig:"METHOD"`
Url string `envconfig:"URL"`
URL string `envconfig:"URL"`
}
func (c *Config) loadWebhookSettings() error {
+6 -6
View File
@@ -20,9 +20,9 @@ func TestWebhookSettings_Load_Success(t *testing.T) {
require.NoError(t, err)
success := conf.Sync.WebhookSettings.Success
assert.Equal(t, "http://success.example.com", success.Url)
assert.Equal(t, "http://success.example.com", success.URL)
assert.Equal(t, "POST", success.Method)
assert.Equal(t, "{\"status\":\"ok\"}", success.Body)
assert.JSONEq(t, `{"status":"ok"}`, success.Body)
assert.Equal(t, map[string]string{
"Content-Type": "application/json",
"Authorization": "Bearer token",
@@ -43,9 +43,9 @@ func TestWebhookSettings_Load_Failure(t *testing.T) {
require.NoError(t, err)
failure := conf.Sync.WebhookSettings.Failure
assert.Equal(t, "http://failure.example.com", failure.Url)
assert.Equal(t, "http://failure.example.com", failure.URL)
assert.Equal(t, "PUT", failure.Method)
assert.Equal(t, "{\"status\":\"error\"}", failure.Body)
assert.JSONEq(t, `{"status":"error"}`, failure.Body)
assert.Equal(t, map[string]string{
"Content-Type": "application/json",
}, failure.Headers)
@@ -89,8 +89,8 @@ func TestWebhookSettings_EmptyURLs(t *testing.T) {
err := conf.loadWebhookSettings()
require.NoError(t, err)
assert.Empty(t, conf.Sync.WebhookSettings.Success.Url)
assert.Empty(t, conf.Sync.WebhookSettings.Failure.Url)
assert.Empty(t, conf.Sync.WebhookSettings.Success.URL)
assert.Empty(t, conf.Sync.WebhookSettings.Failure.URL)
}
func TestWebhookSettings_ClientConfiguration(t *testing.T) {
+1
View File
@@ -57,5 +57,6 @@ func Init() {
}
}
//nolint:reassign //not passed around
log.Logger = logger
}
+11 -66
View File
@@ -36,12 +36,12 @@ func (_m *Client) EXPECT() *Client_Expecter {
return &Client_Expecter{mock: &_m.Mock}
}
// ApiPath provides a mock function for the type Client
func (_mock *Client) ApiPath(target string) string {
// APIPath provides a mock function for the type Client
func (_mock *Client) APIPath(target string) string {
ret := _mock.Called(target)
if len(ret) == 0 {
panic("no return value specified for ApiPath")
panic("no return value specified for APIPath")
}
var r0 string
@@ -53,30 +53,30 @@ func (_mock *Client) ApiPath(target string) string {
return r0
}
// Client_ApiPath_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ApiPath'
type Client_ApiPath_Call struct {
// Client_APIPath_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'APIPath'
type Client_APIPath_Call struct {
*mock.Call
}
// ApiPath is a helper method to define mock.On call
// APIPath is a helper method to define mock.On call
// - target
func (_e *Client_Expecter) ApiPath(target interface{}) *Client_ApiPath_Call {
return &Client_ApiPath_Call{Call: _e.mock.On("ApiPath", target)}
func (_e *Client_Expecter) APIPath(target interface{}) *Client_APIPath_Call {
return &Client_APIPath_Call{Call: _e.mock.On("APIPath", target)}
}
func (_c *Client_ApiPath_Call) Run(run func(target string)) *Client_ApiPath_Call {
func (_c *Client_APIPath_Call) Run(run func(target string)) *Client_APIPath_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *Client_ApiPath_Call) Return(s string) *Client_ApiPath_Call {
func (_c *Client_APIPath_Call) Return(s string) *Client_APIPath_Call {
_c.Call.Return(s)
return _c
}
func (_c *Client_ApiPath_Call) RunAndReturn(run func(target string) string) *Client_ApiPath_Call {
func (_c *Client_APIPath_Call) RunAndReturn(run func(target string) string) *Client_APIPath_Call {
_c.Call.Return(run)
return _c
}
@@ -235,61 +235,6 @@ func (_c *Client_GetTeleporter_Call) RunAndReturn(run func() ([]byte, error)) *C
return _c
}
// GetVersion provides a mock function for the type Client
func (_mock *Client) GetVersion() (*model.VersionResponse, error) {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for GetVersion")
}
var r0 *model.VersionResponse
var r1 error
if returnFunc, ok := ret.Get(0).(func() (*model.VersionResponse, error)); ok {
return returnFunc()
}
if returnFunc, ok := ret.Get(0).(func() *model.VersionResponse); ok {
r0 = returnFunc()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.VersionResponse)
}
}
if returnFunc, ok := ret.Get(1).(func() error); ok {
r1 = returnFunc()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Client_GetVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetVersion'
type Client_GetVersion_Call struct {
*mock.Call
}
// GetVersion is a helper method to define mock.On call
func (_e *Client_Expecter) GetVersion() *Client_GetVersion_Call {
return &Client_GetVersion_Call{Call: _e.mock.On("GetVersion")}
}
func (_c *Client_GetVersion_Call) Run(run func()) *Client_GetVersion_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *Client_GetVersion_Call) Return(versionResponse *model.VersionResponse, err error) *Client_GetVersion_Call {
_c.Call.Return(versionResponse, err)
return _c
}
func (_c *Client_GetVersion_Call) RunAndReturn(run func() (*model.VersionResponse, error)) *Client_GetVersion_Call {
_c.Call.Return(run)
return _c
}
// PatchConfig provides a mock function for the type Client
func (_mock *Client) PatchConfig(patchRequest *model.PatchConfigRequest) error {
ret := _mock.Called(patchRequest)
+51 -77
View File
@@ -2,6 +2,7 @@ package pihole
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -18,26 +19,25 @@ var (
userAgent = fmt.Sprintf("nebula-sync/%s", version.Version)
)
func NewClient(piHole model.PiHole, httpClient *http.Client) Client {
logger := log.With().Str("client", piHole.Url.String()).Logger()
return &client{
piHole: piHole,
logger: &logger,
httpClient: httpClient,
}
}
type Client interface {
PostAuth() error
DeleteSession() error
GetVersion() (*model.VersionResponse, error)
GetTeleporter() ([]byte, error)
PostTeleporter(payload []byte, teleporterRequest *model.PostTeleporterRequest) error
GetConfig() (configResponse *model.ConfigResponse, err error)
PatchConfig(patchRequest *model.PatchConfigRequest) error
PostRunGravity() error
String() string
ApiPath(target string) string
APIPath(target string) string
}
func NewClient(piHole model.PiHole, httpClient *http.Client) Client {
logger := log.With().Str("client", piHole.URL.String()).Logger()
return &client{
piHole: piHole,
logger: &logger,
httpClient: httpClient,
}
}
type client struct {
@@ -71,7 +71,7 @@ func (client *client) PostAuth() error {
return client.wrapError(err, nil)
}
req, err := http.NewRequest("POST", client.ApiPath("/auth"), bytes.NewReader(reqBytes))
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, client.APIPath("/auth"), bytes.NewReader(reqBytes))
if err != nil {
return client.wrapError(err, req)
@@ -84,8 +84,9 @@ func (client *client) PostAuth() error {
if err != nil {
return client.wrapError(err, req)
}
defer response.Body.Close()
if err := successfulHttpStatus(response.StatusCode); err != nil {
if err := successfulHTTPStatus(response.StatusCode); err != nil {
return client.wrapError(err, req)
}
@@ -119,78 +120,46 @@ func (client *client) DeleteSession() error {
return nil
}
req, err := http.NewRequest("DELETE", client.ApiPath("auth"), nil)
req, err := http.NewRequestWithContext(context.Background(), http.MethodDelete, client.APIPath("auth"), nil)
if err != nil {
return client.wrapError(err, req)
}
req.Header.Set("sid", client.auth.sid)
req.Header.Set("Sid", client.auth.sid)
req.Header.Set("User-Agent", userAgent)
response, err := client.httpClient.Do(req)
if err != nil {
return client.wrapError(err, req)
}
defer response.Body.Close()
if err := successfulHttpStatus(response.StatusCode); err != nil {
if err := successfulHTTPStatus(response.StatusCode); err != nil {
return client.wrapError(err, req)
}
return client.wrapError(err, req)
}
func (client *client) GetVersion() (*model.VersionResponse, error) {
client.logger.Debug().Msg("Get version")
versionResponse := model.VersionResponse{}
if err := client.auth.verify(); err != nil {
return &versionResponse, client.wrapError(err, nil)
}
req, err := http.NewRequest("GET", client.ApiPath("info/version"), nil)
if err != nil {
return &versionResponse, client.wrapError(err, req)
}
req.Header.Set("sid", client.auth.sid)
req.Header.Set("User-Agent", userAgent)
response, err := client.httpClient.Do(req)
if err != nil {
return &versionResponse, client.wrapError(err, req)
}
if err := successfulHttpStatus(response.StatusCode); err != nil {
return &versionResponse, client.wrapError(err, req)
}
body, err := io.ReadAll(response.Body)
if err != nil {
return &versionResponse, client.wrapError(err, req)
}
err = json.Unmarshal(body, &versionResponse)
return &versionResponse, client.wrapError(err, req)
}
func (client *client) GetTeleporter() ([]byte, error) {
client.logger.Debug().Msg("Get teleporter")
if err := client.auth.verify(); err != nil {
return nil, client.wrapError(err, nil)
}
req, err := http.NewRequest("GET", client.ApiPath("teleporter"), nil)
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, client.APIPath("teleporter"), nil)
if err != nil {
return nil, client.wrapError(err, req)
}
req.Header.Set("sid", client.auth.sid)
req.Header.Set("Sid", client.auth.sid)
req.Header.Set("User-Agent", userAgent)
response, err := client.httpClient.Do(req)
if err != nil {
return nil, client.wrapError(err, req)
}
defer response.Body.Close()
if err := successfulHttpStatus(response.StatusCode); err != nil {
if err := successfulHTTPStatus(response.StatusCode); err != nil {
return nil, client.wrapError(err, req)
}
@@ -227,11 +196,11 @@ func (client *client) PostTeleporter(payload []byte, teleporterRequest *model.Po
return client.wrapError(err, nil)
}
req, err := http.NewRequest("POST", client.ApiPath("teleporter"), &requestBody)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, client.APIPath("teleporter"), &requestBody)
if err != nil {
return client.wrapError(err, req)
}
req.Header.Set("sid", client.auth.sid)
req.Header.Set("Sid", client.auth.sid)
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("User-Agent", userAgent)
@@ -239,46 +208,49 @@ func (client *client) PostTeleporter(payload []byte, teleporterRequest *model.Po
if err != nil {
return client.wrapError(err, req)
}
defer response.Body.Close()
if err := successfulHttpStatus(response.StatusCode); err != nil {
if err := successfulHTTPStatus(response.StatusCode); err != nil {
return client.wrapError(err, req)
}
return nil
}
func (client *client) GetConfig() (configResponse *model.ConfigResponse, err error) {
func (client *client) GetConfig() (*model.ConfigResponse, error) {
var configResponse model.ConfigResponse
client.logger.Debug().Msg("Get config")
if err := client.auth.verify(); err != nil {
return configResponse, client.wrapError(err, nil)
return &configResponse, client.wrapError(err, nil)
}
req, err := http.NewRequest("GET", client.ApiPath("config"), nil)
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, client.APIPath("config"), nil)
if err != nil {
return configResponse, client.wrapError(err, req)
return &configResponse, client.wrapError(err, req)
}
req.Header.Set("sid", client.auth.sid)
req.Header.Set("Sid", client.auth.sid)
req.Header.Set("User-Agent", userAgent)
response, err := client.httpClient.Do(req)
if err != nil {
return configResponse, client.wrapError(err, req)
return &configResponse, client.wrapError(err, req)
}
defer response.Body.Close()
if err := successfulHttpStatus(response.StatusCode); err != nil {
return configResponse, client.wrapError(err, req)
if err := successfulHTTPStatus(response.StatusCode); err != nil {
return &configResponse, client.wrapError(err, req)
}
body, err := io.ReadAll(response.Body)
if err != nil {
return configResponse, client.wrapError(err, req)
return &configResponse, client.wrapError(err, req)
}
if err := json.Unmarshal(body, &configResponse); err != nil {
return configResponse, client.wrapError(err, req)
return &configResponse, client.wrapError(err, req)
}
return configResponse, client.wrapError(err, req)
return &configResponse, client.wrapError(err, req)
}
func (client *client) PatchConfig(patchRequest *model.PatchConfigRequest) error {
@@ -292,19 +264,20 @@ func (client *client) PatchConfig(patchRequest *model.PatchConfigRequest) error
return client.wrapError(err, nil)
}
req, err := http.NewRequest("PATCH", client.ApiPath("config"), bytes.NewReader(reqBytes))
req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, client.APIPath("config"), bytes.NewReader(reqBytes))
if err != nil {
return client.wrapError(err, req)
}
req.Header.Set("sid", client.auth.sid)
req.Header.Set("Sid", client.auth.sid)
req.Header.Set("User-Agent", userAgent)
response, err := client.httpClient.Do(req)
if err != nil {
return client.wrapError(err, req)
}
defer response.Body.Close()
if err := successfulHttpStatus(response.StatusCode); err != nil {
if err := successfulHTTPStatus(response.StatusCode); err != nil {
return client.wrapError(err, req)
}
@@ -317,19 +290,20 @@ func (client *client) PostRunGravity() error {
return client.wrapError(err, nil)
}
req, err := http.NewRequest("POST", client.ApiPath("action/gravity"), nil)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, client.APIPath("action/gravity"), nil)
if err != nil {
return client.wrapError(err, req)
}
req.Header.Set("sid", client.auth.sid)
req.Header.Set("Sid", client.auth.sid)
req.Header.Set("User-Agent", userAgent)
response, err := client.httpClient.Do(req)
if err != nil {
return client.wrapError(err, req)
}
defer response.Body.Close()
if err := successfulHttpStatus(response.StatusCode); err != nil {
if err := successfulHTTPStatus(response.StatusCode); err != nil {
return client.wrapError(err, req)
}
@@ -337,11 +311,11 @@ func (client *client) PostRunGravity() error {
}
func (client *client) String() string {
return client.piHole.Url.String()
return client.piHole.URL.String()
}
func (client *client) ApiPath(target string) string {
return client.piHole.Url.JoinPath("api", target).String()
func (client *client) APIPath(target string) string {
return client.piHole.URL.JoinPath("api", target).String()
}
func (client *client) wrapError(err error, req *http.Request) error {
@@ -354,7 +328,7 @@ func (client *client) wrapError(err error, req *http.Request) error {
return nil
}
func successfulHttpStatus(statusCode int) error {
func successfulHTTPStatus(statusCode int) error {
if statusCode >= 200 && statusCode <= 299 {
return nil
}
+13 -20
View File
@@ -30,7 +30,7 @@ type clientTestSuite struct {
func (suite *clientTestSuite) SetupTest() {
client := createClient(piHole)
err := client.PostAuth()
require.NoError(suite.T(), err)
suite.Require().NoError(err)
suite.client = client
}
@@ -41,27 +41,20 @@ func TestClientIntegration(t *testing.T) {
func (suite *clientTestSuite) TestClient_Authenticate() {
err := suite.client.PostAuth()
assert.NoError(suite.T(), err)
suite.Require().NoError(err)
}
func (suite *clientTestSuite) TestClient_DeleteSession() {
err := suite.client.DeleteSession()
assert.NoError(suite.T(), err)
}
func (suite *clientTestSuite) TestClient_GetVersion() {
version, err := suite.client.GetVersion()
assert.NoError(suite.T(), err)
assert.NotNil(suite.T(), version)
suite.Require().NoError(err)
}
func (suite *clientTestSuite) TestClient_GetTeleporter() {
payload, err := suite.client.GetTeleporter()
assert.NoError(suite.T(), err)
assert.NotNil(suite.T(), payload)
suite.Require().NoError(err)
suite.NotNil(suite.T(), payload)
}
func (suite *clientTestSuite) TestClient_PostTeleporter() {
@@ -80,14 +73,14 @@ func (suite *clientTestSuite) TestClient_PostTeleporter() {
},
})
assert.NoError(suite.T(), err)
suite.Require().NoError(err)
}
func (suite *clientTestSuite) TestClient_GetConfig() {
conf, err := suite.client.GetConfig()
assert.NoError(suite.T(), err)
assert.NotNil(suite.T(), conf)
suite.Require().NoError(err)
suite.NotNil(suite.T(), conf)
}
func (suite *clientTestSuite) TestClient_PatchConfig() {
@@ -103,13 +96,13 @@ func (suite *clientTestSuite) TestClient_PatchConfig() {
}}
err := suite.client.PatchConfig(&request)
assert.NoError(suite.T(), err)
suite.Require().NoError(err)
}
func (suite *clientTestSuite) TestClient_PostRunGravity() {
err := suite.client.PostRunGravity()
assert.NoError(suite.T(), err)
suite.Require().NoError(err)
}
func TestClient_String(t *testing.T) {
@@ -124,7 +117,7 @@ func TestClient_ApiPath(t *testing.T) {
c := NewClient(piHole, httpClient)
url := c.String()
path := c.ApiPath("testing")
path := c.APIPath("testing")
expectedPath := fmt.Sprintf("%s/api/testing", url)
assert.Equal(t, expectedPath, path)
@@ -137,10 +130,10 @@ func Test_auth_verify(t *testing.T) {
validity: 0,
valid: false,
}
assert.Error(t, a.verify())
require.Error(t, a.verify())
a.valid = true
assert.NoError(t, a.verify())
require.NoError(t, a.verify())
}
func createClient(container tc.Container) Client {
+13 -12
View File
@@ -1,6 +1,7 @@
package model
import (
"errors"
"fmt"
"github.com/rs/zerolog/log"
"net/url"
@@ -8,14 +9,10 @@ import (
)
type PiHole struct {
Url *url.URL
URL *url.URL
Password string
}
func (ph PiHole) String() string {
return fmt.Sprintf("{Url:%s}", ph.Url)
}
func NewPiHole(host, password string) PiHole {
u, err := url.Parse(host)
if err != nil {
@@ -23,26 +20,30 @@ func NewPiHole(host, password string) PiHole {
}
return PiHole{
Url: u,
URL: u,
Password: password,
}
}
func (piHole *PiHole) Decode(value string) error {
func (ph *PiHole) String() string {
return fmt.Sprintf("{URL:%s}", ph.URL)
}
func (ph *PiHole) Decode(value string) error {
uri, password, found := strings.Cut(value, "|")
if !found {
return fmt.Errorf("invalid pihole format")
return errors.New("invalid pihole format")
}
parsedUrl, err := url.Parse(uri)
parsedURL, err := url.Parse(uri)
if err != nil {
return fmt.Errorf("parse url: %s", err)
return fmt.Errorf("parse url: %w", err)
}
*piHole = PiHole{
Url: parsedUrl,
*ph = PiHole{
URL: parsedURL,
Password: password,
}
return nil
+5 -4
View File
@@ -3,6 +3,7 @@ package model
import (
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net/url"
"testing"
)
@@ -13,11 +14,11 @@ func TestPiHole_Decode(t *testing.T) {
const pw = "asdfa|sdf"
err := ph.Decode(fmt.Sprintf("%s|%s", uri, pw))
assert.NoError(t, err)
require.NoError(t, err)
expectedUrl, err := url.Parse(uri)
assert.NoError(t, err)
expectedURL, err := url.Parse(uri)
require.NoError(t, err)
assert.Equal(t, expectedUrl, ph.Url)
assert.Equal(t, expectedURL, ph.URL)
assert.Equal(t, pw, ph.Password)
}
+7 -7
View File
@@ -21,13 +21,13 @@ type PostTeleporterRequest struct {
}
type PatchConfig struct {
DNS map[string]interface{} `json:"dns"`
DHCP map[string]interface{} `json:"dhcp"`
NTP map[string]interface{} `json:"ntp"`
Resolver map[string]interface{} `json:"resolver"`
Database map[string]interface{} `json:"database"`
Misc map[string]interface{} `json:"misc"`
Debug map[string]interface{} `json:"debug"`
DNS map[string]any `json:"dns"`
DHCP map[string]any `json:"dhcp"`
NTP map[string]any `json:"ntp"`
Resolver map[string]any `json:"resolver"`
Database map[string]any `json:"database"`
Misc map[string]any `json:"misc"`
Debug map[string]any `json:"debug"`
}
type PatchConfigRequest struct {
+9 -47
View File
@@ -16,59 +16,21 @@ type AuthResponse struct {
} `json:"session"`
}
type VersionResponse struct {
Version struct {
Core struct {
Local struct {
Version string `json:"version"`
Branch string `json:"branch"`
Hash string `json:"hash"`
} `json:"local"`
Remote struct {
Version interface{} `json:"version"`
Hash string `json:"hash"`
} `json:"remote"`
} `json:"core"`
Web struct {
Local struct {
Version string `json:"version"`
Branch string `json:"branch"`
Hash string `json:"hash"`
} `json:"local"`
Remote struct {
Version interface{} `json:"version"`
Hash string `json:"hash"`
} `json:"remote"`
} `json:"web"`
Ftl struct {
Local struct {
Hash string `json:"hash"`
Branch string `json:"branch"`
Version string `json:"version"`
Date string `json:"date"`
} `json:"local"`
Remote struct {
Version interface{} `json:"version"`
Hash string `json:"hash"`
} `json:"remote"`
} `json:"ftl"`
Docker struct {
Local string `json:"local"`
Remote string `json:"remote"`
} `json:"docker"`
} `json:"version"`
Took float64 `json:"took"`
}
type ConfigResponse struct {
Config map[string]interface{} `json:"config"`
Config map[string]any `json:"config"`
}
func (c *ConfigResponse) Get(key string) map[string]interface{} {
func (c *ConfigResponse) Get(key string) map[string]any {
value, exists := c.Config[key]
if !exists {
log.Warn().Msg(fmt.Sprintf("Missing key (%s) in config response", key))
return nil
}
return value.(map[string]interface{})
extracted, ok := value.(map[string]any)
if !ok {
log.Warn().Msg(fmt.Sprintf("Received unexpected type for key (%s) in config response", key))
return nil
}
return extracted
}
+11 -7
View File
@@ -33,7 +33,7 @@ func Init() (*Service, error) {
return nil, err
}
httpClient := conf.Client.NewHttpClient()
httpClient := conf.Client.NewHTTPClient()
retry.Init(conf.Client)
primary := pihole.NewClient(conf.Primary, httpClient)
@@ -68,12 +68,8 @@ func (service *Service) Run() error {
return nil
}
func (service *Service) doSync(t sync.Target) (err error) {
if service.conf.Sync.FullSync {
err = t.FullSync(service.conf.Sync)
} else {
err = t.SelectiveSync(service.conf.Sync)
}
func (service *Service) doSync(t sync.Target) error {
err := service.selectSyncMethod(t)
if err != nil {
for _, callback := range service.callbacks {
@@ -89,6 +85,14 @@ func (service *Service) doSync(t sync.Target) (err error) {
return err
}
func (service *Service) selectSyncMethod(t sync.Target) error {
if service.conf.Sync.FullSync {
return t.FullSync(service.conf.Sync)
}
return t.SelectiveSync(service.conf.Sync)
}
func (service *Service) startCron(cmd func()) error {
cron := cron.New()
+14 -14
View File
@@ -24,7 +24,7 @@ func (ft Type) String() string {
return s
}
func ByType(filter Type, keys []string, json map[string]interface{}) (map[string]interface{}, error) {
func ByType(filter Type, keys []string, json map[string]any) (map[string]any, error) {
switch filter {
case Include:
return includeKeys(json, keys), nil
@@ -35,8 +35,8 @@ func ByType(filter Type, keys []string, json map[string]interface{}) (map[string
}
}
func includeKeys(jsonData map[string]interface{}, keys []string) map[string]interface{} {
result := make(map[string]interface{})
func includeKeys(jsonData map[string]any, keys []string) map[string]any {
result := make(map[string]any)
for _, key := range keys {
value := getNestedValue(jsonData, key)
@@ -50,7 +50,7 @@ func includeKeys(jsonData map[string]interface{}, keys []string) map[string]inte
return result
}
func excludeKeys(jsonData map[string]interface{}, keys []string) map[string]interface{} {
func excludeKeys(jsonData map[string]any, keys []string) map[string]any {
result := deepCopy(jsonData)
for _, key := range keys {
@@ -60,11 +60,11 @@ func excludeKeys(jsonData map[string]interface{}, keys []string) map[string]inte
return result
}
func getNestedValue(data map[string]interface{}, key string) interface{} {
func getNestedValue(data map[string]any, key string) any {
keys := strings.Split(key, ".")
current := data
for i, k := range keys {
if next, ok := current[k].(map[string]interface{}); ok {
if next, ok := current[k].(map[string]any); ok {
current = next
if i == len(keys)-1 {
return next
@@ -78,15 +78,15 @@ func getNestedValue(data map[string]interface{}, key string) interface{} {
return current
}
func setNestedValue(target map[string]interface{}, key string, value interface{}) {
func setNestedValue(target map[string]any, key string, value any) {
keys := strings.Split(key, ".")
current := target
for _, k := range keys[:len(keys)-1] {
if _, exists := current[k]; !exists {
current[k] = make(map[string]interface{})
current[k] = make(map[string]any)
}
if next, ok := current[k].(map[string]interface{}); ok {
if next, ok := current[k].(map[string]any); ok {
current = next
}
}
@@ -95,7 +95,7 @@ func setNestedValue(target map[string]interface{}, key string, value interface{}
current[lastKey] = value
}
func removeNestedKey(target map[string]interface{}, keys []string) {
func removeNestedKey(target map[string]any, keys []string) {
if len(keys) == 0 {
return
}
@@ -114,7 +114,7 @@ func removeNestedKey(target map[string]interface{}, keys []string) {
return
}
if nested, exists := target[currentKey].(map[string]interface{}); exists {
if nested, exists := target[currentKey].(map[string]any); exists {
removeNestedKey(nested, remainingKeys)
if len(nested) == 0 {
delete(target, currentKey)
@@ -122,11 +122,11 @@ func removeNestedKey(target map[string]interface{}, keys []string) {
}
}
func deepCopy(original map[string]interface{}) map[string]interface{} {
copied := make(map[string]interface{})
func deepCopy(original map[string]any) map[string]any {
copied := make(map[string]any)
for key, value := range original {
switch v := value.(type) {
case map[string]interface{}:
case map[string]any:
copied[key] = deepCopy(v)
default:
copied[key] = v
+32 -27
View File
@@ -3,6 +3,7 @@ package filter
import (
"encoding/json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"maps"
"os"
"slices"
@@ -11,10 +12,11 @@ import (
func TestFilter_ByType_Include(t *testing.T) {
filterKeys := []string{"cache", "upstreams", "interface"}
data := loadDnsData()
data := loadDNSData()
result, err := ByType(Include, filterKeys, data)
assert.NoError(t, err)
assert.Equal(t, len(result), len(filterKeys))
require.NoError(t, err)
assert.Len(t, filterKeys, 3)
assert.Len(t, result, 3)
for key := range maps.Keys(data) {
if slices.Contains(filterKeys, key) {
@@ -28,9 +30,9 @@ func TestFilter_ByType_Include(t *testing.T) {
func TestFilter_ByType_Exclude(t *testing.T) {
filterKeys := []string{"cache", "upstreams", "interface"}
data := loadDnsData()
data := loadDNSData()
result, err := ByType(Exclude, filterKeys, data)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, len(result), len(data)-len(filterKeys))
for key := range maps.Keys(data) {
@@ -45,28 +47,31 @@ func TestFilter_ByType_Exclude(t *testing.T) {
func TestFilter_ByType_MultipleNested(t *testing.T) {
filterKeys := []string{"reply.host.force4", "reply.host.IPv4", "reply.blocking.force4"}
data := loadDnsData()
data := loadDNSData()
result, err := ByType(Include, filterKeys, data)
assert.NoError(t, err)
assert.Equal(t, len(result), 1)
require.NoError(t, err)
assert.Len(t, result, 1)
reply := result["reply"].(map[string]interface{})
host := reply["host"].(map[string]interface{})
blocking := reply["blocking"].(map[string]interface{})
reply, ok := result["reply"].(map[string]any)
assert.True(t, ok)
host, ok := reply["host"].(map[string]any)
assert.True(t, ok)
blocking, ok := reply["blocking"].(map[string]any)
assert.True(t, ok)
assert.Equal(t, len(reply), 2)
assert.Equal(t, len(host), 2)
assert.Equal(t, len(blocking), 1)
assert.NotEqual(t, data["reply"].(map[string]interface{}), reply)
assert.Len(t, reply, 2)
assert.Len(t, host, 2)
assert.Len(t, blocking, 1)
assert.NotEqual(t, data["reply"].(map[string]any), reply)
}
func loadDnsData() map[string]interface{} {
func loadDNSData() map[string]any {
file, err := os.ReadFile("../../../testdata/dns.json")
if err != nil {
panic("failed to read testdata")
}
var data map[string]interface{}
var data map[string]any
if err := json.Unmarshal(file, &data); err != nil {
panic("failed to unmarshal testdata")
}
@@ -75,9 +80,9 @@ func loadDnsData() map[string]interface{} {
}
func TestFilter_IncludeKeys(t *testing.T) {
data := map[string]interface{}{
data := map[string]any{
"a": 1,
"b": map[string]interface{}{"c": 2, "d": 3},
"b": map[string]any{"c": 2, "d": 3},
"e": 4,
}
@@ -85,14 +90,14 @@ func TestFilter_IncludeKeys(t *testing.T) {
result := includeKeys(data, keys)
assert.Equal(t, 1, result["a"])
assert.Equal(t, 2, result["b"].(map[string]interface{})["c"])
assert.Equal(t, nil, result["b"].(map[string]interface{})["d"])
assert.Equal(t, 2, result["b"].(map[string]any)["c"])
assert.Nil(t, result["b"].(map[string]any)["d"])
assert.Equal(t, 4, result["e"])
assert.Len(t, result, 3)
}
func TestFilter_IncludeKeys_MissingKey(t *testing.T) {
data := map[string]interface{}{"a": 1}
data := map[string]any{"a": 1}
keys := []string{"b"}
result := includeKeys(data, keys)
@@ -100,9 +105,9 @@ func TestFilter_IncludeKeys_MissingKey(t *testing.T) {
}
func TestFilter_ExcludeKeys(t *testing.T) {
data := map[string]interface{}{
data := map[string]any{
"a": 1,
"b": map[string]interface{}{"c": 2, "d": 3},
"b": map[string]any{"c": 2, "d": 3},
"e": 4,
}
@@ -110,13 +115,13 @@ func TestFilter_ExcludeKeys(t *testing.T) {
result := excludeKeys(data, keys)
assert.NotContains(t, result, "a")
assert.NotContains(t, result["b"].(map[string]interface{}), "c")
assert.Contains(t, result["b"].(map[string]interface{}), "d")
assert.NotContains(t, result["b"].(map[string]any), "c")
assert.Contains(t, result["b"].(map[string]any), "d")
assert.Contains(t, result, "e")
}
func TestFilter_ExcludeKeys_NonExistentKey(t *testing.T) {
data := map[string]interface{}{"a": 1}
data := map[string]any{"a": 1}
keys := []string{"b"}
result := excludeKeys(data, keys)
+1 -1
View File
@@ -6,7 +6,7 @@ import (
"github.com/lovelaze/nebula-sync/internal/config"
)
func (target *target) FullSync(conf *config.Sync) (err error) {
func (target *target) FullSync(conf *config.Sync) error {
return target.sync(func() error {
return target.full(conf)
}, "full")
+5 -4
View File
@@ -3,6 +3,7 @@ package retry
import (
"errors"
"github.com/lovelaze/nebula-sync/internal/config"
"github.com/stretchr/testify/require"
"testing"
"time"
@@ -31,7 +32,7 @@ func TestWithRetry_DelayBetweenRetries(t *testing.T) {
elapsed := time.Since(start)
assert.NoError(t, err, "Expected success before max attempts")
require.NoError(t, err, "Expected success before max attempts")
assert.GreaterOrEqual(t, elapsed.Seconds(), 2.0, "Expected at least 2 seconds of delay between all retries")
assert.LessOrEqual(t, elapsed.Seconds(), 2.5, "Expected at most 2.5 seconds of delay between all retries")
}
@@ -50,7 +51,7 @@ func TestWithRetry_NoRetriesOnImmediateSuccess(t *testing.T) {
return nil
}, 5) // 5 attempts, 2-second delay
assert.NoError(t, err, "Expected no error when function succeeds immediately")
require.NoError(t, err, "Expected no error when function succeeds immediately")
assert.Equal(t, 1, counter, "Expected function to run only once without retries")
}
@@ -71,7 +72,7 @@ func TestWithRetry_SuccessAfterRetries(t *testing.T) {
return nil
}, 3) // 3 attempts, 1-second delay
assert.NoError(t, err, "Expected success before max attempts")
require.NoError(t, err, "Expected success before max attempts")
assert.Equal(t, 2, counter, "Expected function to retry once before success")
}
@@ -89,6 +90,6 @@ func TestWithRetry_MaxAttemptsFailure(t *testing.T) {
return errors.New("test error")
}, 3) // 3 attempts, 1-second delay
assert.Error(t, err, "Expected an error after max attempts")
require.Error(t, err, "Expected an error after max attempts")
assert.Equal(t, 3, counter, "Expected function to be retried 3 times")
}
+8 -7
View File
@@ -28,7 +28,8 @@ func NewTarget(primary pihole.Client, replicas []pihole.Client) Target {
}
}
func (target *target) sync(syncFunc func() error, mode string) (err error) {
func (target *target) sync(syncFunc func() error, mode string) error {
var err error
log.Info().Str("mode", mode).Int("replicas", len(target.Replicas)).Msg("Running sync")
defer func() {
@@ -45,7 +46,7 @@ func (target *target) sync(syncFunc func() error, mode string) (err error) {
return syncFunc()
}
func (target *target) authenticate() (err error) {
func (target *target) authenticate() error {
log.Info().Msg("Authenticating clients...")
if err := target.Primary.PostAuth(); err != nil {
return err
@@ -59,7 +60,7 @@ func (target *target) authenticate() (err error) {
}
}
return err
return nil
}
func (target *target) deleteSessions() {
@@ -84,7 +85,7 @@ func (target *target) syncTeleporters(gravitySettings *config.GravitySettings) e
return err
}
var teleporterRequest *model.PostTeleporterRequest = nil
var teleporterRequest *model.PostTeleporterRequest
if gravitySettings != nil {
teleporterRequest = createPostTeleporterRequest(gravitySettings)
}
@@ -166,18 +167,18 @@ func createPatchConfigRequest(config *config.ConfigSettings, configResponse *mod
return &model.PatchConfigRequest{Config: patchConfig}
}
func filterPatchConfigRequest(setting *config.ConfigSetting, json map[string]interface{}) map[string]interface{} {
func filterPatchConfigRequest(setting *config.ConfigSetting, json map[string]any) map[string]any {
if !setting.Enabled {
return nil
}
if setting.Filter != nil {
filteredJson, err := filter.ByType(setting.Filter.Type, setting.Filter.Keys, json)
filteredJSON, err := filter.ByType(setting.Filter.Type, setting.Filter.Keys, json)
if err != nil {
log.Warn().Err(err).Msg("Unable to filter json object")
return nil
}
return filteredJson
return filteredJSON
}
return json
+8 -8
View File
@@ -165,13 +165,13 @@ func Test_filterPatchConfigRequest_disabled(t *testing.T) {
}
func emptyConfigResponse() *model.ConfigResponse {
return &model.ConfigResponse{Config: map[string]interface{}{
"dns": map[string]interface{}{},
"dhcp": map[string]interface{}{},
"ntp": map[string]interface{}{},
"resolver": map[string]interface{}{},
"database": map[string]interface{}{},
"misc": map[string]interface{}{},
"debug": map[string]interface{}{},
return &model.ConfigResponse{Config: map[string]any{
"dns": map[string]any{},
"dhcp": map[string]any{},
"ntp": map[string]any{},
"resolver": map[string]any{},
"database": map[string]any{},
"misc": map[string]any{},
"debug": map[string]any{},
}}
}
+11 -5
View File
@@ -1,6 +1,7 @@
package webhook
import (
"context"
"crypto/tls"
"fmt"
"net/http"
@@ -12,6 +13,11 @@ import (
"github.com/rs/zerolog/log"
)
const (
timeout = 10 * time.Second
invalidHTTPStatusCodeThreshold = 400
)
type Client struct {
success config.WebhookRequest
failure config.WebhookRequest
@@ -23,7 +29,7 @@ func NewClient(c *config.WebhookSettings) *Client {
success: c.Success,
failure: c.Failure,
httpClient: &http.Client{
Timeout: 10 * time.Second,
Timeout: timeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: c.Client.SkipTLSVerification},
}},
@@ -51,18 +57,18 @@ func (c *Client) triggerFailure() error {
}
func invoke(client *http.Client, settings config.WebhookRequest) error {
if settings.Url == "" {
if settings.URL == "" {
return nil
}
log.Debug().
Str("url", settings.Url).
Str("url", settings.URL).
Str("method", settings.Method).
Str("body", settings.Body).
Interface("headers", settings.Headers).
Msg("Invoking webhook")
req, err := http.NewRequest(settings.Method, settings.Url, strings.NewReader(settings.Body))
req, err := http.NewRequestWithContext(context.Background(), settings.Method, settings.URL, strings.NewReader(settings.Body))
if err != nil {
return fmt.Errorf("create webhook request: %w", err)
}
@@ -79,7 +85,7 @@ func invoke(client *http.Client, settings config.WebhookRequest) error {
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
if resp.StatusCode >= invalidHTTPStatusCodeThreshold {
return fmt.Errorf("webhook returned status %d", resp.StatusCode)
}
+6 -6
View File
@@ -20,7 +20,7 @@ func TestWebhook(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header
buf, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.NoError(t, err)
receivedBody = string(buf)
w.WriteHeader(http.StatusOK)
}))
@@ -29,7 +29,7 @@ func TestWebhook(t *testing.T) {
// Create webhook settings
settings := &config.WebhookSettings{
Success: config.WebhookRequest{
Url: ts.URL,
URL: ts.URL,
Method: "POST",
Body: "success-body",
Headers: map[string]string{"X-Test": "success"},
@@ -53,7 +53,7 @@ func TestWebhook(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header
buf, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.NoError(t, err)
receivedBody = string(buf)
w.WriteHeader(http.StatusOK)
}))
@@ -61,7 +61,7 @@ func TestWebhook(t *testing.T) {
settings := &config.WebhookSettings{
Failure: config.WebhookRequest{
Url: ts.URL,
URL: ts.URL,
Method: "PUT",
Body: "failure-body",
Headers: map[string]string{"X-Test": "failure"},
@@ -80,7 +80,7 @@ func TestWebhook(t *testing.T) {
t.Run("empty url skips webhook", func(t *testing.T) {
settings := &config.WebhookSettings{
Success: config.WebhookRequest{
Url: "",
URL: "",
},
}
@@ -97,7 +97,7 @@ func TestWebhook(t *testing.T) {
settings := &config.WebhookSettings{
Success: config.WebhookRequest{
Url: ts.URL,
URL: ts.URL,
},
}
+4 -1
View File
@@ -2,8 +2,11 @@ package main
import (
"github.com/lovelaze/nebula-sync/cmd"
"os"
)
func main() {
cmd.Execute()
if err := cmd.Execute(); err != nil {
os.Exit(1)
}
}