diff --git a/cmd/traefik/plugins.go b/cmd/traefik/plugins.go index 2c3d5dd25..ef939d2a0 100644 --- a/cmd/traefik/plugins.go +++ b/cmd/traefik/plugins.go @@ -2,43 +2,62 @@ package main import ( "fmt" + "net/http" + "path/filepath" + "time" + "github.com/hashicorp/go-retryablehttp" + "github.com/rs/zerolog/log" "github.com/traefik/traefik/v3/pkg/config/static" + "github.com/traefik/traefik/v3/pkg/logs" "github.com/traefik/traefik/v3/pkg/plugins" ) const outputDir = "./plugins-storage/" func createPluginBuilder(staticConfiguration *static.Configuration) (*plugins.Builder, error) { - client, plgs, localPlgs, err := initPlugins(staticConfiguration) + manager, plgs, localPlgs, err := initPlugins(staticConfiguration) if err != nil { return nil, err } - return plugins.NewBuilder(client, plgs, localPlgs) + return plugins.NewBuilder(manager, plgs, localPlgs) } -func initPlugins(staticCfg *static.Configuration) (*plugins.Client, map[string]plugins.Descriptor, map[string]plugins.LocalDescriptor, error) { +func initPlugins(staticCfg *static.Configuration) (*plugins.Manager, map[string]plugins.Descriptor, map[string]plugins.LocalDescriptor, error) { err := checkUniquePluginNames(staticCfg.Experimental) if err != nil { return nil, nil, nil, err } - var client *plugins.Client + var manager *plugins.Manager plgs := map[string]plugins.Descriptor{} if hasPlugins(staticCfg) { - opts := plugins.ClientOptions{ + httpClient := retryablehttp.NewClient() + httpClient.Logger = logs.NewRetryableHTTPLogger(log.Logger) + httpClient.HTTPClient = &http.Client{Timeout: 10 * time.Second} + httpClient.RetryMax = 3 + + // Create separate downloader for HTTP operations + archivesPath := filepath.Join(outputDir, "archives") + downloader, err := plugins.NewRegistryDownloader(plugins.RegistryDownloaderOptions{ + HTTPClient: httpClient.HTTPClient, + ArchivesPath: archivesPath, + }) + if err != nil { + return nil, nil, nil, fmt.Errorf("unable to create plugin downloader: %w", err) + } + + opts := plugins.ManagerOptions{ Output: outputDir, } - - var err error - client, err = plugins.NewClient(opts) + manager, err = plugins.NewManager(downloader, opts) if err != nil { - return nil, nil, nil, fmt.Errorf("unable to create plugins client: %w", err) + return nil, nil, nil, fmt.Errorf("unable to create plugins manager: %w", err) } - err = plugins.SetupRemotePlugins(client, staticCfg.Experimental.Plugins) + err = plugins.SetupRemotePlugins(manager, staticCfg.Experimental.Plugins) if err != nil { return nil, nil, nil, fmt.Errorf("unable to set up plugins environment: %w", err) } @@ -57,7 +76,7 @@ func initPlugins(staticCfg *static.Configuration) (*plugins.Client, map[string]p localPlgs = staticCfg.Experimental.LocalPlugins } - return client, plgs, localPlgs, nil + return manager, plgs, localPlgs, nil } func checkUniquePluginNames(e *static.Experimental) error { diff --git a/docs/content/reference/install-configuration/configuration-options.md b/docs/content/reference/install-configuration/configuration-options.md index dd4a0f798..14673995a 100644 --- a/docs/content/reference/install-configuration/configuration-options.md +++ b/docs/content/reference/install-configuration/configuration-options.md @@ -128,6 +128,7 @@ THIS FILE MUST NOT BE EDITED BY HAND | experimental.localplugins._name_.settings.mounts | Directory to mount to the wasm guest. | | | experimental.localplugins._name_.settings.useunsafe | Allow the plugin to use unsafe package. | false | | experimental.otlplogs | Enables the OpenTelemetry logs integration. | false | +| experimental.plugins._name_.hash | plugin's hash to validate' | | | experimental.plugins._name_.modulename | plugin's module name. | | | experimental.plugins._name_.settings | Plugin's settings (works only for wasm plugins). | | | experimental.plugins._name_.settings.envs | Environment variables to forward to the wasm guest. | | diff --git a/pkg/plugins/builder.go b/pkg/plugins/builder.go index 8559dec90..96a4bf21e 100644 --- a/pkg/plugins/builder.go +++ b/pkg/plugins/builder.go @@ -28,7 +28,7 @@ type Builder struct { } // NewBuilder creates a new Builder. -func NewBuilder(client *Client, plugins map[string]Descriptor, localPlugins map[string]LocalDescriptor) (*Builder, error) { +func NewBuilder(manager *Manager, plugins map[string]Descriptor, localPlugins map[string]LocalDescriptor) (*Builder, error) { ctx := context.Background() pb := &Builder{ @@ -37,9 +37,9 @@ func NewBuilder(client *Client, plugins map[string]Descriptor, localPlugins map[ } for pName, desc := range plugins { - manifest, err := client.ReadManifest(desc.ModuleName) + manifest, err := manager.ReadManifest(desc.ModuleName) if err != nil { - _ = client.ResetAll() + _ = manager.ResetAll() return nil, fmt.Errorf("%s: failed to read manifest: %w", desc.ModuleName, err) } @@ -52,7 +52,7 @@ func NewBuilder(client *Client, plugins map[string]Descriptor, localPlugins map[ switch manifest.Type { case typeMiddleware: - middleware, err := newMiddlewareBuilder(logCtx, client.GoPath(), manifest, desc.ModuleName, desc.Settings) + middleware, err := newMiddlewareBuilder(logCtx, manager.GoPath(), manifest, desc.ModuleName, desc.Settings) if err != nil { return nil, err } @@ -60,7 +60,7 @@ func NewBuilder(client *Client, plugins map[string]Descriptor, localPlugins map[ pb.middlewareBuilders[pName] = middleware case typeProvider: - pBuilder, err := newProviderBuilder(logCtx, manifest, client.GoPath(), desc.Settings) + pBuilder, err := newProviderBuilder(logCtx, manifest, manager.GoPath(), desc.Settings) if err != nil { return nil, fmt.Errorf("%s: %w", desc.ModuleName, err) } diff --git a/pkg/plugins/downloader.go b/pkg/plugins/downloader.go new file mode 100644 index 000000000..3dce1df4c --- /dev/null +++ b/pkg/plugins/downloader.go @@ -0,0 +1,160 @@ +package plugins + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "path/filepath" +) + +// PluginDownloader defines the interface for downloading and validating plugins from remote sources. +type PluginDownloader interface { + // Download downloads a plugin archive and returns its hash. + Download(ctx context.Context, pName, pVersion string) (string, error) + // Check checks the plugin archive integrity against a known hash. + Check(ctx context.Context, pName, pVersion, hash string) error +} + +// RegistryDownloaderOptions holds configuration options for creating a RegistryDownloader. +type RegistryDownloaderOptions struct { + HTTPClient *http.Client + ArchivesPath string +} + +// RegistryDownloader implements PluginDownloader for HTTP-based plugin downloads. +type RegistryDownloader struct { + httpClient *http.Client + baseURL *url.URL + archives string +} + +// NewRegistryDownloader creates a new HTTP-based plugin downloader. +func NewRegistryDownloader(opts RegistryDownloaderOptions) (*RegistryDownloader, error) { + baseURL, err := url.Parse(pluginsURL) + if err != nil { + return nil, err + } + + httpClient := opts.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + + return &RegistryDownloader{ + httpClient: httpClient, + baseURL: baseURL, + archives: opts.ArchivesPath, + }, nil +} + +// Download downloads a plugin archive. +func (d *RegistryDownloader) Download(ctx context.Context, pName, pVersion string) (string, error) { + filename := d.buildArchivePath(pName, pVersion) + + var hash string + _, err := os.Stat(filename) + if err != nil && !os.IsNotExist(err) { + return "", fmt.Errorf("failed to read archive %s: %w", filename, err) + } + + if err == nil { + hash, err = computeHash(filename) + if err != nil { + return "", fmt.Errorf("failed to compute hash: %w", err) + } + } + + endpoint, err := d.baseURL.Parse(path.Join(d.baseURL.Path, "download", pName, pVersion)) + if err != nil { + return "", fmt.Errorf("failed to parse endpoint URL: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + if hash != "" { + req.Header.Set(hashHeader, hash) + } + + resp, err := d.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to call service: %w", err) + } + + defer func() { _ = resp.Body.Close() }() + + switch resp.StatusCode { + case http.StatusNotModified: + return hash, nil + case http.StatusOK: + err = os.MkdirAll(filepath.Dir(filename), 0o755) + if err != nil { + return "", fmt.Errorf("failed to create directory: %w", err) + } + + var file *os.File + file, err = os.Create(filename) + if err != nil { + return "", fmt.Errorf("failed to create file %q: %w", filename, err) + } + + defer func() { _ = file.Close() }() + + _, err = io.Copy(file, resp.Body) + if err != nil { + return "", fmt.Errorf("failed to write response: %w", err) + } + + hash, err = computeHash(filename) + if err != nil { + return "", fmt.Errorf("failed to compute hash: %w", err) + } + default: + data, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("error: %d: %s", resp.StatusCode, string(data)) + } + + return hash, nil +} + +// Check checks the plugin archive integrity. +func (d *RegistryDownloader) Check(ctx context.Context, pName, pVersion, hash string) error { + endpoint, err := d.baseURL.Parse(path.Join(d.baseURL.Path, "validate", pName, pVersion)) + if err != nil { + return fmt.Errorf("failed to parse endpoint URL: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + if hash != "" { + req.Header.Set(hashHeader, hash) + } + + resp, err := d.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to call service: %w", err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusOK { + return nil + } + + return errors.New("plugin integrity check failed") +} + +// buildArchivePath builds the path to a plugin archive file. +func (d *RegistryDownloader) buildArchivePath(pName, pVersion string) string { + return filepath.Join(d.archives, filepath.FromSlash(pName), pVersion+".zip") +} diff --git a/pkg/plugins/downloader_test.go b/pkg/plugins/downloader_test.go new file mode 100644 index 000000000..bcbd89424 --- /dev/null +++ b/pkg/plugins/downloader_test.go @@ -0,0 +1,159 @@ +package plugins + +import ( + "archive/zip" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHTTPPluginDownloader_Download(t *testing.T) { + tests := []struct { + name string + serverResponse func(w http.ResponseWriter, r *http.Request) + fileAlreadyExists bool + expectError bool + }{ + { + name: "successful download", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/zip") + w.WriteHeader(http.StatusOK) + + require.NoError(t, fillDummyZip(w)) + }, + }, + { + name: "not modified response", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "", http.StatusNotModified) + }, + fileAlreadyExists: true, + }, + { + name: "server error", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "internal server error", http.StatusInternalServerError) + }, + expectError: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(test.serverResponse)) + defer server.Close() + + tempDir := t.TempDir() + archivesPath := filepath.Join(tempDir, "archives") + + if test.fileAlreadyExists { + createDummyZip(t, archivesPath) + } + + baseURL, err := url.Parse(server.URL) + require.NoError(t, err) + + downloader := &RegistryDownloader{ + httpClient: server.Client(), + baseURL: baseURL, + archives: archivesPath, + } + + ctx := t.Context() + hash, err := downloader.Download(ctx, "test/plugin", "v1.0.0") + + if test.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, hash) + + // Check if archive file was created + archivePath := downloader.buildArchivePath("test/plugin", "v1.0.0") + assert.FileExists(t, archivePath) + } + }) + } +} + +func TestHTTPPluginDownloader_Check(t *testing.T) { + tests := []struct { + name string + serverResponse func(w http.ResponseWriter, r *http.Request) + expectError require.ErrorAssertionFunc + }{ + { + name: "successful check", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }, + expectError: require.NoError, + }, + { + name: "failed check", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + expectError: require.Error, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(test.serverResponse)) + defer server.Close() + + tempDir := t.TempDir() + archivesPath := filepath.Join(tempDir, "archives") + + baseURL, err := url.Parse(server.URL) + require.NoError(t, err) + + downloader := &RegistryDownloader{ + httpClient: server.Client(), + baseURL: baseURL, + archives: archivesPath, + } + + ctx := t.Context() + + err = downloader.Check(ctx, "test/plugin", "v1.0.0", "testhash") + test.expectError(t, err) + }) + } +} + +func createDummyZip(t *testing.T, path string) { + t.Helper() + + err := os.MkdirAll(path+"/test/plugin/", 0o755) + require.NoError(t, err) + + zipfile, err := os.Create(path + "/test/plugin/v1.0.0.zip") + require.NoError(t, err) + defer zipfile.Close() + + err = fillDummyZip(zipfile) + require.NoError(t, err) +} + +func fillDummyZip(w io.Writer) error { + writer := zip.NewWriter(w) + + file, err := writer.Create("test.txt") + if err != nil { + return err + } + + _, _ = file.Write([]byte("test content")) + _ = writer.Close() + return nil +} diff --git a/pkg/plugins/client.go b/pkg/plugins/manager.go similarity index 52% rename from pkg/plugins/client.go rename to pkg/plugins/manager.go index 8d119ddec..2bbb9cbe7 100644 --- a/pkg/plugins/client.go +++ b/pkg/plugins/manager.go @@ -9,17 +9,10 @@ import ( "errors" "fmt" "io" - "net/http" - "net/url" "os" - "path" "path/filepath" "strings" - "time" - "github.com/hashicorp/go-retryablehttp" - "github.com/rs/zerolog/log" - "github.com/traefik/traefik/v3/pkg/logs" "golang.org/x/mod/module" "golang.org/x/mod/zip" "gopkg.in/yaml.v3" @@ -39,31 +32,26 @@ const ( hashHeader = "X-Plugin-Hash" ) -// ClientOptions the options of a Traefik plugins client. -type ClientOptions struct { +// ManagerOptions the options of a Traefik plugins manager. +type ManagerOptions struct { Output string } -// Client a Traefik plugins client. -type Client struct { - HTTPClient *http.Client - baseURL *url.URL +// Manager manages Traefik plugins lifecycle operations including storage, and manifest reading. +type Manager struct { + downloader PluginDownloader - archives string stateFile string - goPath string - sources string + + archives string + sources string + goPath string } -// NewClient creates a new Traefik plugins client. -func NewClient(opts ClientOptions) (*Client, error) { - baseURL, err := url.Parse(pluginsURL) - if err != nil { - return nil, err - } - +// NewManager creates a new Traefik plugins manager. +func NewManager(downloader PluginDownloader, opts ManagerOptions) (*Manager, error) { sourcesRootPath := filepath.Join(filepath.FromSlash(opts.Output), sourcesFolder) - err = resetDirectory(sourcesRootPath) + err := resetDirectory(sourcesRootPath) if err != nil { return nil, err } @@ -79,31 +67,48 @@ func NewClient(opts ClientOptions) (*Client, error) { return nil, fmt.Errorf("failed to create archives directory %s: %w", archivesPath, err) } - client := retryablehttp.NewClient() - client.Logger = logs.NewRetryableHTTPLogger(log.Logger) - client.HTTPClient = &http.Client{Timeout: 10 * time.Second} - client.RetryMax = 3 - - return &Client{ - HTTPClient: client.StandardClient(), - baseURL: baseURL, - - archives: archivesPath, - stateFile: filepath.Join(archivesPath, stateFilename), - - goPath: goPath, - sources: filepath.Join(goPath, goPathSrc), + return &Manager{ + downloader: downloader, + stateFile: filepath.Join(archivesPath, stateFilename), + archives: archivesPath, + sources: filepath.Join(goPath, goPathSrc), + goPath: goPath, }, nil } +// InstallPlugin download and unzip the given plugin. +func (m *Manager) InstallPlugin(ctx context.Context, plugin Descriptor) error { + hash, err := m.downloader.Download(ctx, plugin.ModuleName, plugin.Version) + if err != nil { + return fmt.Errorf("unable to download plugin %s: %w", plugin.ModuleName, err) + } + + if plugin.Hash != "" { + if plugin.Hash != hash { + return fmt.Errorf("invalid hash for plugin %s, expected %s, got %s", plugin.ModuleName, plugin.Hash, hash) + } + } else { + err = m.downloader.Check(ctx, plugin.ModuleName, plugin.Version, hash) + if err != nil { + return fmt.Errorf("unable to check archive integrity of the plugin %s: %w", plugin.ModuleName, err) + } + } + + if err = m.unzip(plugin.ModuleName, plugin.Version); err != nil { + return fmt.Errorf("unable to unzip plugin %s: %w", plugin.ModuleName, err) + } + + return nil +} + // GoPath gets the plugins GoPath. -func (c *Client) GoPath() string { - return c.goPath +func (m *Manager) GoPath() string { + return m.goPath } // ReadManifest reads a plugin manifest. -func (c *Client) ReadManifest(moduleName string) (*Manifest, error) { - return ReadManifest(c.goPath, moduleName) +func (m *Manager) ReadManifest(moduleName string) (*Manifest, error) { + return ReadManifest(m.goPath, moduleName) } // ReadManifest reads a plugin manifest. @@ -126,114 +131,74 @@ func ReadManifest(goPath, moduleName string) (*Manifest, error) { return m, nil } -// Download downloads a plugin archive. -func (c *Client) Download(ctx context.Context, pName, pVersion string) (string, error) { - filename := c.buildArchivePath(pName, pVersion) - - var hash string - _, err := os.Stat(filename) - if err != nil && !os.IsNotExist(err) { - return "", fmt.Errorf("failed to read archive %s: %w", filename, err) - } - - if err == nil { - hash, err = computeHash(filename) - if err != nil { - return "", fmt.Errorf("failed to compute hash: %w", err) - } - } - - endpoint, err := c.baseURL.Parse(path.Join(c.baseURL.Path, "download", pName, pVersion)) - if err != nil { - return "", fmt.Errorf("failed to parse endpoint URL: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - if hash != "" { - req.Header.Set(hashHeader, hash) - } - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return "", fmt.Errorf("failed to call service: %w", err) - } - - defer func() { _ = resp.Body.Close() }() - - switch resp.StatusCode { - case http.StatusNotModified: - // noop - return hash, nil - - case http.StatusOK: - err = os.MkdirAll(filepath.Dir(filename), 0o755) - if err != nil { - return "", fmt.Errorf("failed to create directory: %w", err) - } - - var file *os.File - file, err = os.Create(filename) - if err != nil { - return "", fmt.Errorf("failed to create file %q: %w", filename, err) - } - - defer func() { _ = file.Close() }() - - _, err = io.Copy(file, resp.Body) - if err != nil { - return "", fmt.Errorf("failed to write response: %w", err) - } - - hash, err = computeHash(filename) - if err != nil { - return "", fmt.Errorf("failed to compute hash: %w", err) - } - - return hash, nil - - default: - data, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("error: %d: %s", resp.StatusCode, string(data)) - } -} - -// Check checks the plugin archive integrity. -func (c *Client) Check(ctx context.Context, pName, pVersion, hash string) error { - endpoint, err := c.baseURL.Parse(path.Join(c.baseURL.Path, "validate", pName, pVersion)) - if err != nil { - return fmt.Errorf("failed to parse endpoint URL: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - if hash != "" { - req.Header.Set(hashHeader, hash) - } - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return fmt.Errorf("failed to call service: %w", err) - } - - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode == http.StatusOK { +// CleanArchives cleans plugins archives. +func (m *Manager) CleanArchives(plugins map[string]Descriptor) error { + if _, err := os.Stat(m.stateFile); os.IsNotExist(err) { return nil } - return errors.New("plugin integrity check failed") + stateFile, err := os.Open(m.stateFile) + if err != nil { + return fmt.Errorf("failed to open state file %s: %w", m.stateFile, err) + } + + previous := make(map[string]string) + err = json.NewDecoder(stateFile).Decode(&previous) + if err != nil { + return fmt.Errorf("failed to decode state file %s: %w", m.stateFile, err) + } + + for pName, pVersion := range previous { + for _, desc := range plugins { + if desc.ModuleName == pName && desc.Version != pVersion { + archivePath := m.buildArchivePath(pName, pVersion) + if err = os.RemoveAll(archivePath); err != nil { + return fmt.Errorf("failed to remove archive %s: %w", archivePath, err) + } + } + } + } + + return nil } -// Unzip unzip a plugin archive. -func (c *Client) Unzip(pName, pVersion string) error { - err := c.unzipModule(pName, pVersion) +// WriteState writes the plugins state files. +func (m *Manager) WriteState(plugins map[string]Descriptor) error { + state := make(map[string]string) + + for _, descriptor := range plugins { + state[descriptor.ModuleName] = descriptor.Version + } + + mp, err := json.MarshalIndent(state, "", " ") + if err != nil { + return fmt.Errorf("unable to marshal plugin state: %w", err) + } + + return os.WriteFile(m.stateFile, mp, 0o600) +} + +// ResetAll resets all plugins related directories. +func (m *Manager) ResetAll() error { + if m.goPath == "" { + return errors.New("goPath is empty") + } + + err := resetDirectory(filepath.Join(m.goPath, "..")) + if err != nil { + return fmt.Errorf("unable to reset plugins GoPath directory %s: %w", m.goPath, err) + } + + err = resetDirectory(m.archives) + if err != nil { + return fmt.Errorf("unable to reset plugins archives directory: %w", err) + } + + return nil +} + +func (m *Manager) unzip(pName, pVersion string) error { + err := m.unzipModule(pName, pVersion) if err == nil { return nil } @@ -241,18 +206,18 @@ func (c *Client) Unzip(pName, pVersion string) error { // Unzip as a generic archive if the module unzip fails. // This is useful for plugins that have vendor directories or other structures. // This is also useful for wasm plugins. - return c.unzipArchive(pName, pVersion) + return m.unzipArchive(pName, pVersion) } -func (c *Client) unzipModule(pName, pVersion string) error { - src := c.buildArchivePath(pName, pVersion) - dest := filepath.Join(c.sources, filepath.FromSlash(pName)) +func (m *Manager) unzipModule(pName, pVersion string) error { + src := m.buildArchivePath(pName, pVersion) + dest := filepath.Join(m.sources, filepath.FromSlash(pName)) return zip.Unzip(dest, module.Version{Path: pName, Version: pVersion}, src) } -func (c *Client) unzipArchive(pName, pVersion string) error { - zipPath := c.buildArchivePath(pName, pVersion) +func (m *Manager) unzipArchive(pName, pVersion string) error { + zipPath := m.buildArchivePath(pName, pVersion) archive, err := zipa.OpenReader(zipPath) if err != nil { @@ -261,10 +226,10 @@ func (c *Client) unzipArchive(pName, pVersion string) error { defer func() { _ = archive.Close() }() - dest := filepath.Join(c.sources, filepath.FromSlash(pName)) + dest := filepath.Join(m.sources, filepath.FromSlash(pName)) for _, f := range archive.File { - err = unzipFile(f, dest) + err = m.unzipFile(f, dest) if err != nil { return fmt.Errorf("unable to unzip %s: %w", f.Name, err) } @@ -273,7 +238,7 @@ func (c *Client) unzipArchive(pName, pVersion string) error { return nil } -func unzipFile(f *zipa.File, dest string) error { +func (m *Manager) unzipFile(f *zipa.File, dest string) error { rc, err := f.Open() if err != nil { return err @@ -341,74 +306,8 @@ func unzipFile(f *zipa.File, dest string) error { return nil } -// CleanArchives cleans plugins archives. -func (c *Client) CleanArchives(plugins map[string]Descriptor) error { - if _, err := os.Stat(c.stateFile); os.IsNotExist(err) { - return nil - } - - stateFile, err := os.Open(c.stateFile) - if err != nil { - return fmt.Errorf("failed to open state file %s: %w", c.stateFile, err) - } - - previous := make(map[string]string) - err = json.NewDecoder(stateFile).Decode(&previous) - if err != nil { - return fmt.Errorf("failed to decode state file %s: %w", c.stateFile, err) - } - - for pName, pVersion := range previous { - for _, desc := range plugins { - if desc.ModuleName == pName && desc.Version != pVersion { - archivePath := c.buildArchivePath(pName, pVersion) - if err = os.RemoveAll(archivePath); err != nil { - return fmt.Errorf("failed to remove archive %s: %w", archivePath, err) - } - } - } - } - - return nil -} - -// WriteState writes the plugins state files. -func (c *Client) WriteState(plugins map[string]Descriptor) error { - m := make(map[string]string) - - for _, descriptor := range plugins { - m[descriptor.ModuleName] = descriptor.Version - } - - mp, err := json.MarshalIndent(m, "", " ") - if err != nil { - return fmt.Errorf("unable to marshal plugin state: %w", err) - } - - return os.WriteFile(c.stateFile, mp, 0o600) -} - -// ResetAll resets all plugins related directories. -func (c *Client) ResetAll() error { - if c.goPath == "" { - return errors.New("goPath is empty") - } - - err := resetDirectory(filepath.Join(c.goPath, "..")) - if err != nil { - return fmt.Errorf("unable to reset plugins GoPath directory %s: %w", c.goPath, err) - } - - err = resetDirectory(c.archives) - if err != nil { - return fmt.Errorf("unable to reset plugins archives directory: %w", err) - } - - return nil -} - -func (c *Client) buildArchivePath(pName, pVersion string) string { - return filepath.Join(c.archives, filepath.FromSlash(pName), pVersion+".zip") +func (m *Manager) buildArchivePath(pName, pVersion string) string { + return filepath.Join(m.archives, filepath.FromSlash(pName), pVersion+".zip") } func resetDirectory(dir string) error { diff --git a/pkg/plugins/manager_test.go b/pkg/plugins/manager_test.go new file mode 100644 index 000000000..5100b2a6b --- /dev/null +++ b/pkg/plugins/manager_test.go @@ -0,0 +1,341 @@ +package plugins + +import ( + zipa "archive/zip" + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +// mockDownloader is a test implementation of PluginDownloader +type mockDownloader struct { + downloadFunc func(ctx context.Context, pName, pVersion string) (string, error) + checkFunc func(ctx context.Context, pName, pVersion, hash string) error +} + +func (m *mockDownloader) Download(ctx context.Context, pName, pVersion string) (string, error) { + if m.downloadFunc != nil { + return m.downloadFunc(ctx, pName, pVersion) + } + return "mockhash", nil +} + +func (m *mockDownloader) Check(ctx context.Context, pName, pVersion, hash string) error { + if m.checkFunc != nil { + return m.checkFunc(ctx, pName, pVersion, hash) + } + return nil +} + +func TestPluginManager_ReadManifest(t *testing.T) { + tempDir := t.TempDir() + opts := ManagerOptions{Output: tempDir} + + downloader := &mockDownloader{} + manager, err := NewManager(downloader, opts) + require.NoError(t, err) + + moduleName := "github.com/test/plugin" + pluginPath := filepath.Join(manager.goPath, "src", moduleName) + err = os.MkdirAll(pluginPath, 0o755) + require.NoError(t, err) + + manifest := &Manifest{ + DisplayName: "Test Plugin", + Type: "middleware", + Import: "github.com/test/plugin", + Summary: "A test plugin", + TestData: map[string]interface{}{ + "test": "data", + }, + } + + manifestPath := filepath.Join(pluginPath, pluginManifest) + manifestData, err := yaml.Marshal(manifest) + require.NoError(t, err) + err = os.WriteFile(manifestPath, manifestData, 0o644) + require.NoError(t, err) + + readManifest, err := manager.ReadManifest(moduleName) + require.NoError(t, err) + assert.Equal(t, manifest.DisplayName, readManifest.DisplayName) + assert.Equal(t, manifest.Type, readManifest.Type) + assert.Equal(t, manifest.Import, readManifest.Import) + assert.Equal(t, manifest.Summary, readManifest.Summary) +} + +func TestPluginManager_ReadManifest_NotFound(t *testing.T) { + tempDir := t.TempDir() + opts := ManagerOptions{Output: tempDir} + + downloader := &mockDownloader{} + manager, err := NewManager(downloader, opts) + require.NoError(t, err) + + _, err = manager.ReadManifest("nonexistent/plugin") + assert.Error(t, err) +} + +func TestPluginManager_CleanArchives(t *testing.T) { + tempDir := t.TempDir() + opts := ManagerOptions{Output: tempDir} + + downloader := &mockDownloader{} + manager, err := NewManager(downloader, opts) + require.NoError(t, err) + + testPlugin1 := "test/plugin1" + testPlugin2 := "test/plugin2" + + archive1Dir := filepath.Join(manager.archives, "test", "plugin1") + archive2Dir := filepath.Join(manager.archives, "test", "plugin2") + err = os.MkdirAll(archive1Dir, 0o755) + require.NoError(t, err) + err = os.MkdirAll(archive2Dir, 0o755) + require.NoError(t, err) + + archive1Old := filepath.Join(archive1Dir, "v1.0.0.zip") + archive1New := filepath.Join(archive1Dir, "v2.0.0.zip") + archive2 := filepath.Join(archive2Dir, "v1.0.0.zip") + + err = os.WriteFile(archive1Old, []byte("old archive"), 0o644) + require.NoError(t, err) + err = os.WriteFile(archive1New, []byte("new archive"), 0o644) + require.NoError(t, err) + err = os.WriteFile(archive2, []byte("archive 2"), 0o644) + require.NoError(t, err) + + state := map[string]string{ + testPlugin1: "v1.0.0", + testPlugin2: "v1.0.0", + } + stateData, err := json.MarshalIndent(state, "", " ") + require.NoError(t, err) + err = os.WriteFile(manager.stateFile, stateData, 0o600) + require.NoError(t, err) + + currentPlugins := map[string]Descriptor{ + "plugin1": { + ModuleName: testPlugin1, + Version: "v2.0.0", + }, + "plugin2": { + ModuleName: testPlugin2, + Version: "v1.0.0", + }, + } + + err = manager.CleanArchives(currentPlugins) + require.NoError(t, err) + + assert.NoFileExists(t, archive1Old) + assert.FileExists(t, archive1New) + assert.FileExists(t, archive2) +} + +func TestPluginManager_WriteState(t *testing.T) { + tempDir := t.TempDir() + opts := ManagerOptions{Output: tempDir} + + downloader := &mockDownloader{} + manager, err := NewManager(downloader, opts) + require.NoError(t, err) + + plugins := map[string]Descriptor{ + "plugin1": { + ModuleName: "test/plugin1", + Version: "v1.0.0", + }, + "plugin2": { + ModuleName: "test/plugin2", + Version: "v2.0.0", + }, + } + + err = manager.WriteState(plugins) + require.NoError(t, err) + + assert.FileExists(t, manager.stateFile) + + data, err := os.ReadFile(manager.stateFile) + require.NoError(t, err) + + var state map[string]string + err = json.Unmarshal(data, &state) + require.NoError(t, err) + + expectedState := map[string]string{ + "test/plugin1": "v1.0.0", + "test/plugin2": "v2.0.0", + } + assert.Equal(t, expectedState, state) +} + +func TestPluginManager_ResetAll(t *testing.T) { + tempDir := t.TempDir() + opts := ManagerOptions{Output: tempDir} + + downloader := &mockDownloader{} + manager, err := NewManager(downloader, opts) + require.NoError(t, err) + + testFile := filepath.Join(manager.GoPath(), "test.txt") + err = os.WriteFile(testFile, []byte("test"), 0o644) + require.NoError(t, err) + + archiveFile := filepath.Join(manager.archives, "test.zip") + err = os.WriteFile(archiveFile, []byte("archive"), 0o644) + require.NoError(t, err) + + err = manager.ResetAll() + require.NoError(t, err) + + assert.DirExists(t, manager.archives) + assert.NoFileExists(t, testFile) + assert.NoFileExists(t, archiveFile) +} + +func TestPluginManager_InstallPlugin(t *testing.T) { + tests := []struct { + name string + plugin Descriptor + downloadFunc func(ctx context.Context, pName, pVersion string) (string, error) + checkFunc func(ctx context.Context, pName, pVersion, hash string) error + setupArchive func(t *testing.T, archivePath string) + expectError bool + errorMsg string + }{ + { + name: "successful installation", + plugin: Descriptor{ + ModuleName: "github.com/test/plugin", + Version: "v1.0.0", + Hash: "expected-hash", + }, + downloadFunc: func(ctx context.Context, pName, pVersion string) (string, error) { + return "expected-hash", nil + }, + checkFunc: func(ctx context.Context, pName, pVersion, hash string) error { + return nil + }, + setupArchive: func(t *testing.T, archivePath string) { + t.Helper() + + // Create a valid zip archive + err := os.MkdirAll(filepath.Dir(archivePath), 0o755) + require.NoError(t, err) + + file, err := os.Create(archivePath) + require.NoError(t, err) + defer file.Close() + + // Write a minimal zip file with a test file + writer := zipa.NewWriter(file) + defer writer.Close() + + fileWriter, err := writer.Create("test-module-v1.0.0/main.go") + require.NoError(t, err) + _, err = fileWriter.Write([]byte("package main\n\nfunc main() {}\n")) + require.NoError(t, err) + }, + expectError: false, + }, + { + name: "download error", + plugin: Descriptor{ + ModuleName: "github.com/test/plugin", + Version: "v1.0.0", + }, + downloadFunc: func(ctx context.Context, pName, pVersion string) (string, error) { + return "", assert.AnError + }, + expectError: true, + errorMsg: "unable to download plugin", + }, + { + name: "check error", + plugin: Descriptor{ + ModuleName: "github.com/test/plugin", + Version: "v1.0.0", + Hash: "expected-hash", + }, + downloadFunc: func(ctx context.Context, pName, pVersion string) (string, error) { + return "actual-hash", nil + }, + checkFunc: func(ctx context.Context, pName, pVersion, hash string) error { + return assert.AnError + }, + expectError: true, + errorMsg: "invalid hash for plugin", + }, + { + name: "unzip error - invalid archive", + plugin: Descriptor{ + ModuleName: "github.com/test/plugin", + Version: "v1.0.0", + }, + downloadFunc: func(ctx context.Context, pName, pVersion string) (string, error) { + return "test-hash", nil + }, + checkFunc: func(ctx context.Context, pName, pVersion, hash string) error { + return nil + }, + setupArchive: func(t *testing.T, archivePath string) { + t.Helper() + + // Create an invalid zip archive + err := os.MkdirAll(filepath.Dir(archivePath), 0o755) + require.NoError(t, err) + err = os.WriteFile(archivePath, []byte("invalid zip content"), 0o644) + require.NoError(t, err) + }, + expectError: true, + errorMsg: "unable to unzip plugin", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tempDir := t.TempDir() + opts := ManagerOptions{Output: tempDir} + + downloader := &mockDownloader{ + downloadFunc: test.downloadFunc, + checkFunc: test.checkFunc, + } + + manager, err := NewManager(downloader, opts) + require.NoError(t, err) + + // Setup archive if needed + if test.setupArchive != nil { + archivePath := filepath.Join(manager.archives, + filepath.FromSlash(test.plugin.ModuleName), + test.plugin.Version+".zip") + test.setupArchive(t, archivePath) + } + + ctx := t.Context() + err = manager.InstallPlugin(ctx, test.plugin) + + if test.expectError { + assert.Error(t, err) + if test.errorMsg != "" { + assert.Contains(t, err.Error(), test.errorMsg) + } + } else { + assert.NoError(t, err) + + // Verify that plugin sources were extracted + sourcePath := filepath.Join(manager.sources, filepath.FromSlash(test.plugin.ModuleName)) + assert.DirExists(t, sourcePath) + } + }) + } +} diff --git a/pkg/plugins/plugins.go b/pkg/plugins/plugins.go index 367b6c46c..f7d543154 100644 --- a/pkg/plugins/plugins.go +++ b/pkg/plugins/plugins.go @@ -13,13 +13,13 @@ import ( const localGoPath = "./plugins-local/" // SetupRemotePlugins setup remote plugins environment. -func SetupRemotePlugins(client *Client, plugins map[string]Descriptor) error { +func SetupRemotePlugins(manager *Manager, plugins map[string]Descriptor) error { err := checkRemotePluginsConfiguration(plugins) if err != nil { return fmt.Errorf("invalid configuration: %w", err) } - err = client.CleanArchives(plugins) + err = manager.CleanArchives(plugins) if err != nil { return fmt.Errorf("unable to clean archives: %w", err) } @@ -27,35 +27,20 @@ func SetupRemotePlugins(client *Client, plugins map[string]Descriptor) error { ctx := context.Background() for pAlias, desc := range plugins { - log.Ctx(ctx).Debug().Msgf("Loading of plugin: %s: %s@%s", pAlias, desc.ModuleName, desc.Version) + log.Ctx(ctx).Debug().Msgf("Installing plugin: %s: %s@%s", pAlias, desc.ModuleName, desc.Version) - hash, err := client.Download(ctx, desc.ModuleName, desc.Version) - if err != nil { - _ = client.ResetAll() - return fmt.Errorf("unable to download plugin %s: %w", desc.ModuleName, err) - } - - err = client.Check(ctx, desc.ModuleName, desc.Version, hash) - if err != nil { - _ = client.ResetAll() - return fmt.Errorf("unable to check archive integrity of the plugin %s: %w", desc.ModuleName, err) + if err = manager.InstallPlugin(ctx, desc); err != nil { + _ = manager.ResetAll() + return fmt.Errorf("unable to install plugin %s: %w", pAlias, err) } } - err = client.WriteState(plugins) + err = manager.WriteState(plugins) if err != nil { - _ = client.ResetAll() + _ = manager.ResetAll() return fmt.Errorf("unable to write plugins state: %w", err) } - for _, desc := range plugins { - err = client.Unzip(desc.ModuleName, desc.Version) - if err != nil { - _ = client.ResetAll() - return fmt.Errorf("unable to unzip archive: %w", err) - } - } - return nil } diff --git a/pkg/plugins/types.go b/pkg/plugins/types.go index ccae8dce4..75bb589b3 100644 --- a/pkg/plugins/types.go +++ b/pkg/plugins/types.go @@ -24,6 +24,9 @@ type Descriptor struct { // Version (required) Version string `description:"plugin's version." json:"version,omitempty" toml:"version,omitempty" yaml:"version,omitempty" export:"true"` + // Hash (optional) + Hash string `description:"plugin's hash to validate'" json:"hash,omitempty" toml:"hash,omitempty" yaml:"hash,omitempty" export:"true"` + // Settings (optional) Settings Settings `description:"Plugin's settings (works only for wasm plugins)." json:"settings,omitempty" toml:"settings,omitempty" yaml:"settings,omitempty" export:"true"` }