From 93b17415c8a562b471b7964d43a360ee81b090e3 Mon Sep 17 00:00:00 2001 From: Rhea Ghosh Date: Fri, 24 Apr 2026 10:41:29 -0500 Subject: [PATCH] hackery --- cmd/tailscale/cli/drive.go | 173 ++++++++++++++++++- drive/drive_clone.go | 1 + drive/drive_view.go | 6 + drive/remote.go | 15 +- drive/remote_test.go | 16 ++ drive/share_access.go | 235 ++++++++++++++++++++++++++ drive/share_access_test.go | 303 ++++++++++++++++++++++++++++++++++ ipn/ipnlocal/peerapi_drive.go | 3 + 8 files changed, 739 insertions(+), 13 deletions(-) create mode 100644 drive/share_access.go create mode 100644 drive/share_access_test.go diff --git a/cmd/tailscale/cli/drive.go b/cmd/tailscale/cli/drive.go index 280ff3172..c290dd495 100644 --- a/cmd/tailscale/cli/drive.go +++ b/cmd/tailscale/cli/drive.go @@ -7,8 +7,10 @@ package cli import ( "context" + "flag" "fmt" "path/filepath" + "sort" "strings" "github.com/peterbourgon/ff/v3/ffcli" @@ -16,7 +18,7 @@ import ( ) const ( - driveShareUsage = "tailscale drive share " + driveShareUsage = "tailscale drive share [--users user1,user2 | --group groupname] " driveRenameUsage = "tailscale drive rename " driveUnshareUsage = "tailscale drive unshare " driveListUsage = "tailscale drive list" @@ -27,6 +29,10 @@ func init() { } func driveCmd() *ffcli.Command { + shareFlags := flag.NewFlagSet("share", flag.ContinueOnError) + usersFlag := shareFlags.String("users", "", "comma-separated list of users to share with (share name auto-generated)") + groupFlag := shareFlags.String("group", "", "group name to share with (share name auto-generated, only group members can access)") + return &ffcli.Command{ Name: "drive", ShortHelp: "Share a directory with your tailnet", @@ -42,8 +48,11 @@ func driveCmd() *ffcli.Command { { Name: "share", ShortUsage: driveShareUsage, - Exec: runDriveShare, - ShortHelp: "[ALPHA] Create or modify a share", + FlagSet: shareFlags, + Exec: func(ctx context.Context, args []string) error { + return runDriveShare(ctx, args, *usersFlag, *groupFlag) + }, + ShortHelp: "[ALPHA] Create or modify a share", }, { Name: "rename", @@ -68,12 +77,54 @@ func driveCmd() *ffcli.Command { } // runDriveShare is the entry point for the "tailscale drive share" command. -func runDriveShare(ctx context.Context, args []string) error { - if len(args) != 2 { - return fmt.Errorf("usage: %s", driveShareUsage) +func runDriveShare(ctx context.Context, args []string, usersFlag, groupFlag string) error { + if usersFlag != "" && groupFlag != "" { + return fmt.Errorf("cannot specify both --users and --group") } - name, path := args[0], args[1] + var name, path string + var isGroup bool + + switch { + case usersFlag != "": + // --users joe,rhea → name = "joe+rhea", path from args[0] + if len(args) != 1 { + return fmt.Errorf("usage: tailscale drive share --users user1,user2 ") + } + users := strings.Split(usersFlag, ",") + for i, u := range users { + users[i] = strings.TrimSpace(u) + if users[i] == "" { + return fmt.Errorf("empty username in --users flag") + } + } + if err := validateUsers(ctx, users); err != nil { + return err + } + sort.Strings(users) + name = strings.Join(users, "+") + path = args[0] + + case groupFlag != "": + // --group eng → name = "eng", path from args[0] + if len(args) != 1 { + return fmt.Errorf("usage: tailscale drive share --group groupname ") + } + if err := validateGroup(ctx, groupFlag); err != nil { + return err + } + name = groupFlag + path = args[0] + isGroup = true + + default: + // Traditional: + if len(args) != 2 { + return fmt.Errorf("usage: %s", driveShareUsage) + } + name = args[0] + path = args[1] + } absolutePath, err := filepath.Abs(path) if err != nil { @@ -81,8 +132,9 @@ func runDriveShare(ctx context.Context, args []string) error { } err = localClient.DriveShareSet(ctx, &drive.Share{ - Name: name, - Path: absolutePath, + Name: name, + Path: absolutePath, + IsGroup: isGroup, }) if err == nil { fmt.Printf("Sharing %q as %q\n", path, name) @@ -90,6 +142,109 @@ func runDriveShare(ctx context.Context, args []string) error { return err } +// validateUsers checks that all specified usernames exist in the tailnet and +// resolves display names. It modifies users in place, replacing each entry +// with its resolved display name (which may include a domain qualifier for +// disambiguation). It returns an error if any user is unknown or ambiguous. +func validateUsers(ctx context.Context, users []string) error { + status, err := localClient.Status(ctx) + if err != nil { + return fmt.Errorf("failed to get tailnet status: %w", err) + } + + tailnetDomain := "" + if status.CurrentTailnet != nil { + tailnetDomain = status.CurrentTailnet.Name + } + + // Build a map from short name to list of login names. + type userInfo struct { + loginName string + displayName string + } + shortToUsers := make(map[string][]userInfo) + for _, u := range status.User { + short := drive.LoginShortName(u.LoginName) + display := drive.LoginDisplayName(u.LoginName, tailnetDomain) + shortToUsers[short] = append(shortToUsers[short], userInfo{ + loginName: u.LoginName, + displayName: display, + }) + } + + // Also build a lookup by display name for users specifying name(domain). + displayToUser := make(map[string]userInfo) + for _, infos := range shortToUsers { + for _, info := range infos { + displayToUser[info.displayName] = info + } + } + + for i, u := range users { + // Check if user specified name(domain) form. + if strings.Contains(u, "(") && strings.Contains(u, ")") { + if _, ok := displayToUser[u]; !ok { + known := make([]string, 0) + for d := range displayToUser { + known = append(known, d) + } + sort.Strings(known) + return fmt.Errorf("unknown user %q\nvalid users: %s", u, strings.Join(known, ", ")) + } + users[i] = u + continue + } + + // Plain short name lookup. + matches, ok := shortToUsers[u] + if !ok || len(matches) == 0 { + known := make([]string, 0, len(shortToUsers)) + for k := range shortToUsers { + known = append(known, k) + } + sort.Strings(known) + return fmt.Errorf("unknown user %q\nvalid users: %s", u, strings.Join(known, ", ")) + } + if len(matches) == 1 { + users[i] = matches[0].displayName + continue + } + // Ambiguous: multiple users share the same short name. + options := make([]string, len(matches)) + for j, m := range matches { + options[j] = m.displayName + } + sort.Strings(options) + return fmt.Errorf("ambiguous user %q, did you mean: %s?", u, strings.Join(options, " or ")) + } + return nil +} + +// validateGroup checks that the specified group exists in the tailnet. +func validateGroup(ctx context.Context, group string) error { + status, err := localClient.Status(ctx) + if err != nil { + return fmt.Errorf("failed to get tailnet status: %w", err) + } + + knownGroups := make(map[string]bool) + for _, u := range status.User { + for _, g := range u.Groups { + knownGroups[drive.GroupShortName(g)] = true + } + } + + if !knownGroups[group] { + known := make([]string, 0, len(knownGroups)) + for k := range knownGroups { + known = append(known, k) + } + sort.Strings(known) + return fmt.Errorf("unknown group: %s\nvalid groups: %s", group, strings.Join(known, ", ")) + } + return nil +} + // runDriveUnshare is the entry point for the "tailscale drive unshare" command. func runDriveUnshare(ctx context.Context, args []string) error { if len(args) != 1 { diff --git a/drive/drive_clone.go b/drive/drive_clone.go index 724ebc386..ec9945e92 100644 --- a/drive/drive_clone.go +++ b/drive/drive_clone.go @@ -23,6 +23,7 @@ var _ShareCloneNeedsRegeneration = Share(struct { Path string As string BookmarkData []byte + IsGroup bool }{}) // Clone duplicates src into dst and reports whether it succeeded. diff --git a/drive/drive_view.go b/drive/drive_view.go index 253a2955b..7c22ef6e6 100644 --- a/drive/drive_view.go +++ b/drive/drive_view.go @@ -105,10 +105,16 @@ func (v ShareView) BookmarkData() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.BookmarkData) } +// IsGroup indicates that this share's name corresponds to a group +// identity. When true, only members of the matching group can access +// the share. +func (v ShareView) IsGroup() bool { return v.ж.IsGroup } + // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _ShareViewNeedsRegeneration = Share(struct { Name string Path string As string BookmarkData []byte + IsGroup bool }{}) diff --git a/drive/remote.go b/drive/remote.go index 5f34d0023..d1f8388e8 100644 --- a/drive/remote.go +++ b/drive/remote.go @@ -17,7 +17,7 @@ var ( // for testing. DisallowShareAs = false ErrDriveNotEnabled = errors.New("Taildrive not enabled") - ErrInvalidShareName = errors.New("Share names may only contain the letters a-z, underscore _, parentheses (), or spaces") + ErrInvalidShareName = errors.New("Share names may only contain the letters a-z, underscore _, parentheses (), plus +, or spaces") ) // AllowShareAs reports whether sharing files as a specific user is allowed. @@ -46,6 +46,11 @@ type Share struct { // hold on to a security-scoped bookmark. That bookmark is stored here. See // https://developer.apple.com/documentation/security/app_sandbox/accessing_files_from_the_macos_app_sandbox#4144043 BookmarkData []byte `json:"bookmarkData,omitempty"` + + // IsGroup indicates that this share's name corresponds to a group + // identity. When true, only members of the matching group can access + // the share. + IsGroup bool `json:"isGroup,omitempty"` } func ShareViewsEqual(a, b ShareView) bool { @@ -55,7 +60,7 @@ func ShareViewsEqual(a, b ShareView) bool { if !a.Valid() || !b.Valid() { return false } - return a.Name() == b.Name() && a.Path() == b.Path() && a.As() == b.As() && a.BookmarkData().Equal(b.ж.BookmarkData) + return a.Name() == b.Name() && a.Path() == b.Path() && a.As() == b.As() && a.BookmarkData().Equal(b.ж.BookmarkData) && a.IsGroup() == b.IsGroup() } func SharesEqual(a, b *Share) bool { @@ -65,7 +70,7 @@ func SharesEqual(a, b *Share) bool { if a == nil || b == nil { return false } - return a.Name == b.Name && a.Path == b.Path && a.As == b.As && bytes.Equal(a.BookmarkData, b.BookmarkData) + return a.Name == b.Name && a.Path == b.Path && a.As == b.As && bytes.Equal(a.BookmarkData, b.BookmarkData) && a.IsGroup == b.IsGroup } func CompareShares(a, b *Share) int { @@ -124,6 +129,8 @@ func NormalizeShareName(name string) (string, error) { return "", ErrInvalidShareName } + name = NormalizeShareNameOrder(name) + return name, nil } @@ -136,7 +143,7 @@ func validShareName(name string) bool { continue } switch r { - case '_', ' ', '(', ')': + case '_', ' ', '(', ')', '+': continue } return false diff --git a/drive/remote_test.go b/drive/remote_test.go index c0de1723a..bd409140e 100644 --- a/drive/remote_test.go +++ b/drive/remote_test.go @@ -26,6 +26,22 @@ func TestNormalizeShareName(t *testing.T) { name: "generally good except for .", err: ErrInvalidShareName, }, + { + name: "c++", + want: "c++", + }, + { + name: " my lib (c++) ", + want: "my lib (c++)", + }, + { + name: "rhea+joe", + want: "joe+rhea", + }, + { + name: "Charlie+Alice+Bob", + want: "alice+bob+charlie", + }, } for _, tt := range tests { t.Run(fmt.Sprintf("name %q", tt.name), func(t *testing.T) { diff --git a/drive/share_access.go b/drive/share_access.go new file mode 100644 index 000000000..882b968d8 --- /dev/null +++ b/drive/share_access.go @@ -0,0 +1,235 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package drive + +import ( + "sort" + "strings" + + "tailscale.com/types/views" +) + +// ParseShareAccessNames returns the list of user short names encoded in a +// share name that uses '+' as a separator. Returns nil if the name is not +// a multi-user share. A valid multi-user share must have all non-empty +// segments and at least 2 segments (so "c++" with empty segments returns nil). +func ParseShareAccessNames(shareName string) []string { + if !strings.Contains(shareName, "+") { + return nil + } + parts := strings.Split(shareName, "+") + if len(parts) < 2 { + return nil + } + for _, p := range parts { + if p == "" { + return nil + } + } + return parts +} + +// NormalizeShareNameOrder sorts '+'-separated segments alphabetically. +// Non-multi-user names are returned unchanged. +func NormalizeShareNameOrder(name string) string { + parts := ParseShareAccessNames(name) + if parts == nil { + return name + } + sort.Strings(parts) + return strings.Join(parts, "+") +} + +// IsShareAccessibleByUser checks if the given loginName's short name (the +// part before '@') appears in the share's '+'-separated user list. Returns +// true for non-multi-user shares (no name-based restriction). +func IsShareAccessibleByUser(shareName, loginName string) bool { + parts := ParseShareAccessNames(shareName) + if parts == nil { + return true + } + short := LoginShortName(loginName) + domain := loginDomain(loginName) + for _, p := range parts { + segShort, segDomain := parseShareSegment(p) + if segShort != short { + continue + } + // If the segment has no domain qualifier, match on short name only + // (backward compat). If it has a domain, the login's domain must + // start with that label. + if segDomain == "" { + return true + } + if domain != "" && strings.HasPrefix(domain, segDomain) { + return true + } + } + return false +} + +// FilterPermissionsByIdentity takes ACL-derived permissions and further +// restricts them based on share name access control. For each share: +// - Contains '+' with valid segments: peer's login short name must be listed +// - Has IsGroup=true on the Share: peer must be in a matching group +// - Otherwise: no name-based restriction (ACLs only) +// +// The wildcard "*" permission is preserved but only applies to shares the +// peer can access based on name/group rules. +func FilterPermissionsByIdentity( + aclPerms Permissions, + loginName string, + groups []string, + shares views.SliceView[*Share, ShareView], +) Permissions { + // If there are no shares with name-based restrictions, return as-is. + hasRestricted := false + type shareInfo struct { + accessible bool + } + shareInfos := make(map[string]shareInfo, shares.Len()) + for i := range shares.Len() { + s := shares.At(i) + name := s.Name() + info := shareInfo{accessible: true} + if s.IsGroup() { + hasRestricted = true + info.accessible = matchesGroup(name, groups) + } else if parts := ParseShareAccessNames(name); parts != nil { + hasRestricted = true + info.accessible = false + short := LoginShortName(loginName) + domain := loginDomain(loginName) + for _, p := range parts { + segShort, segDomain := parseShareSegment(p) + if segShort != short { + continue + } + if segDomain == "" { + info.accessible = true + break + } + if domain != "" && strings.HasPrefix(domain, segDomain) { + info.accessible = true + break + } + } + } + shareInfos[name] = info + } + + if !hasRestricted { + return aclPerms + } + + // Expand the wildcard into per-share permissions so we can selectively + // deny access. The Permissions.For method returns max(specific, wildcard), + // so the only way to deny a share under a wildcard is to remove the + // wildcard and grant each accessible share explicitly. + wildcardPerm := aclPerms[wildcardShare] + + filtered := make(Permissions) + + // Copy non-wildcard ACL entries for accessible shares. + for shareName, perm := range aclPerms { + if shareName == wildcardShare { + continue + } + info, ok := shareInfos[shareName] + if !ok { + // Share in ACL but not on this node; keep it. + filtered[shareName] = perm + continue + } + if info.accessible { + filtered[shareName] = perm + } + } + + // If there was a wildcard, expand it to all accessible shares that + // don't already have an explicit (higher) permission. + if wildcardPerm > PermissionNone { + for name, info := range shareInfos { + if info.accessible { + if existing := filtered[name]; wildcardPerm > existing { + filtered[name] = wildcardPerm + } + } + } + } + + return filtered +} + +// LoginShortName extracts the short name from a login name. +// "joe@example.com" → "joe" +func LoginShortName(loginName string) string { + if i := strings.Index(loginName, "@"); i >= 0 { + return loginName[:i] + } + return loginName +} + +// loginDomain extracts the domain part from a login name. +// "alice@example.com" → "example.com" +// "alice" → "" +func loginDomain(loginName string) string { + if i := strings.Index(loginName, "@"); i >= 0 { + return loginName[i+1:] + } + return "" +} + +// LoginDisplayName returns a display name for a login, suitable for use in +// share names. If the login's domain matches tailnetDomain, only the short +// name is returned (e.g. "alice"). Otherwise, the format "shortname(domain)" +// is used (e.g. "alice(company)") where domain has its TLD stripped. +func LoginDisplayName(loginName, tailnetDomain string) string { + short := LoginShortName(loginName) + domain := loginDomain(loginName) + if domain == "" || domain == tailnetDomain { + return short + } + // Strip TLD from domain for brevity: "company.com" → "company" + domainLabel := domain + if i := strings.Index(domainLabel, "."); i >= 0 { + domainLabel = domainLabel[:i] + } + return short + "(" + domainLabel + ")" +} + +// parseShareSegment parses a share name segment that may contain a domain +// qualifier. "alice(company)" returns ("alice", "company"). "alice" returns +// ("alice", ""). +func parseShareSegment(segment string) (shortName, domain string) { + if i := strings.Index(segment, "("); i >= 0 { + if j := strings.Index(segment, ")"); j > i { + return segment[:i], segment[i+1 : j] + } + } + return segment, "" +} + +// matchesGroup checks if the share name matches any of the peer's group +// identifiers. Groups can be in the form "group:eng" or "eng@example.com". +func matchesGroup(shareName string, groups []string) bool { + for _, g := range groups { + if GroupShortName(g) == shareName { + return true + } + } + return false +} + +// GroupShortName extracts a short group name from a group identifier. +// "group:eng" → "eng", "eng@example.com" → "eng" +func GroupShortName(group string) string { + if strings.HasPrefix(group, "group:") { + return strings.TrimPrefix(group, "group:") + } + if i := strings.Index(group, "@"); i >= 0 { + return group[:i] + } + return group +} diff --git a/drive/share_access_test.go b/drive/share_access_test.go new file mode 100644 index 000000000..91840cf10 --- /dev/null +++ b/drive/share_access_test.go @@ -0,0 +1,303 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package drive + +import ( + "testing" + + "tailscale.com/types/views" +) + +func TestParseShareAccessNames(t *testing.T) { + tests := []struct { + name string + want []string + }{ + {"joe+rhea", []string{"joe", "rhea"}}, + {"alice+joe+rhea", []string{"alice", "joe", "rhea"}}, + {"c++", nil}, // empty segments + {"docs", nil}, // no '+' + {"+leading", nil}, // empty first segment + {"trailing+", nil}, // empty last segment + {"a++b", nil}, // empty middle segment + {"a+b", []string{"a", "b"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ParseShareAccessNames(tt.name) + if tt.want == nil { + if got != nil { + t.Errorf("ParseShareAccessNames(%q) = %v, want nil", tt.name, got) + } + return + } + if len(got) != len(tt.want) { + t.Errorf("ParseShareAccessNames(%q) = %v, want %v", tt.name, got, tt.want) + return + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("ParseShareAccessNames(%q)[%d] = %q, want %q", tt.name, i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestNormalizeShareNameOrder(t *testing.T) { + tests := []struct { + name string + want string + }{ + {"rhea+joe", "joe+rhea"}, + {"charlie+alice+bob", "alice+bob+charlie"}, + {"docs", "docs"}, + {"c++", "c++"}, + {"a+b", "a+b"}, // already sorted + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NormalizeShareNameOrder(tt.name) + if got != tt.want { + t.Errorf("NormalizeShareNameOrder(%q) = %q, want %q", tt.name, got, tt.want) + } + }) + } +} + +func TestIsShareAccessibleByUser(t *testing.T) { + tests := []struct { + shareName string + loginName string + want bool + }{ + {"joe+rhea", "joe@example.com", true}, + {"joe+rhea", "rhea@example.com", true}, + {"joe+rhea", "alice@example.com", false}, + {"docs", "anyone@example.com", true}, // not a multi-user share + {"c++", "anyone@example.com", true}, // not a multi-user share (empty segments) + {"joe+rhea", "joe", true}, // no domain + + // name(domain) format + {"alice(contractor)+bob", "alice@contractor.io", true}, + {"alice(contractor)+bob", "alice@example.com", false}, // wrong domain + {"alice(contractor)+bob", "bob@example.com", true}, // bob has no domain qualifier + {"alice(contractor)+bob", "charlie@example.com", false}, // not listed + } + for _, tt := range tests { + t.Run(tt.shareName+"_"+tt.loginName, func(t *testing.T) { + got := IsShareAccessibleByUser(tt.shareName, tt.loginName) + if got != tt.want { + t.Errorf("IsShareAccessibleByUser(%q, %q) = %v, want %v", tt.shareName, tt.loginName, got, tt.want) + } + }) + } +} + +func TestLoginDisplayName(t *testing.T) { + tests := []struct { + loginName string + tailnetDomain string + want string + }{ + {"alice@example.com", "example.com", "alice"}, // home domain + {"alice@contractor.io", "example.com", "alice(contractor)"}, // foreign domain + {"alice@example.com", "bob@gmail.com", "alice(example)"}, // shared domain tailnet + {"alice", "example.com", "alice"}, // no domain in login + {"alice@foo.bar.com", "example.com", "alice(foo)"}, // multi-part domain + } + for _, tt := range tests { + t.Run(tt.loginName+"_"+tt.tailnetDomain, func(t *testing.T) { + got := LoginDisplayName(tt.loginName, tt.tailnetDomain) + if got != tt.want { + t.Errorf("LoginDisplayName(%q, %q) = %q, want %q", tt.loginName, tt.tailnetDomain, got, tt.want) + } + }) + } +} + +func TestParseShareSegment(t *testing.T) { + tests := []struct { + input string + wantShort string + wantDomain string + }{ + {"alice", "alice", ""}, + {"alice(company)", "alice", "company"}, + {"alice(contractor)", "alice", "contractor"}, + {"bob", "bob", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + gotShort, gotDomain := parseShareSegment(tt.input) + if gotShort != tt.wantShort || gotDomain != tt.wantDomain { + t.Errorf("parseShareSegment(%q) = (%q, %q), want (%q, %q)", tt.input, gotShort, gotDomain, tt.wantShort, tt.wantDomain) + } + }) + } +} + +func TestLoginShortName(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"joe@example.com", "joe"}, + {"joe", "joe"}, + {"alice@foo.bar.com", "alice"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := LoginShortName(tt.input) + if got != tt.want { + t.Errorf("LoginShortName(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestMatchesGroup(t *testing.T) { + tests := []struct { + shareName string + groups []string + want bool + }{ + {"eng", []string{"group:eng"}, true}, + {"eng", []string{"eng@example.com"}, true}, + {"eng", []string{"group:design", "group:eng"}, true}, + {"eng", []string{"group:design"}, false}, + {"eng", []string{}, false}, + {"design", []string{"engineering@example.com"}, false}, + } + for _, tt := range tests { + t.Run(tt.shareName, func(t *testing.T) { + got := matchesGroup(tt.shareName, tt.groups) + if got != tt.want { + t.Errorf("matchesGroup(%q, %v) = %v, want %v", tt.shareName, tt.groups, got, tt.want) + } + }) + } +} + +func TestGroupShortName(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"group:eng", "eng"}, + {"eng@example.com", "eng"}, + {"eng", "eng"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := GroupShortName(tt.input) + if got != tt.want { + t.Errorf("GroupShortName(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestFilterPermissionsByIdentity(t *testing.T) { + shares := views.SliceOfViews([]*Share{ + {Name: "joe+rhea"}, + {Name: "docs"}, + {Name: "eng", IsGroup: true}, + {Name: "alice+bob"}, + }) + + t.Run("multi-user share access", func(t *testing.T) { + perms := Permissions{ + "*": PermissionReadWrite, + } + filtered := FilterPermissionsByIdentity(perms, "joe@example.com", nil, shares) + // joe can access joe+rhea and docs, but not eng (group) or alice+bob + if filtered.For("joe+rhea") != PermissionReadWrite { + t.Error("joe should access joe+rhea") + } + if filtered.For("docs") != PermissionReadWrite { + t.Error("joe should access docs") + } + if filtered.For("eng") != PermissionNone { + t.Error("joe should not access eng (not in group)") + } + if filtered.For("alice+bob") != PermissionNone { + t.Error("joe should not access alice+bob") + } + }) + + t.Run("group share access", func(t *testing.T) { + perms := Permissions{ + "*": PermissionReadOnly, + } + filtered := FilterPermissionsByIdentity(perms, "joe@example.com", []string{"group:eng"}, shares) + if filtered.For("eng") != PermissionReadOnly { + t.Error("joe in group:eng should access eng share") + } + }) + + t.Run("specific share permission without wildcard", func(t *testing.T) { + perms := Permissions{ + "joe+rhea": PermissionReadWrite, + "alice+bob": PermissionReadOnly, + } + filtered := FilterPermissionsByIdentity(perms, "joe@example.com", nil, shares) + if filtered.For("joe+rhea") != PermissionReadWrite { + t.Error("joe should have rw to joe+rhea") + } + if filtered.For("alice+bob") != PermissionNone { + t.Error("joe should not access alice+bob") + } + }) + + t.Run("no restricted shares means no filtering", func(t *testing.T) { + perms := Permissions{ + "*": PermissionReadWrite, + } + unrestricted := views.SliceOfViews([]*Share{ + {Name: "docs"}, + {Name: "photos"}, + }) + filtered := FilterPermissionsByIdentity(perms, "joe@example.com", nil, unrestricted) + if filtered.For("docs") != PermissionReadWrite { + t.Error("wildcard should pass through with no restricted shares") + } + }) + + t.Run("empty shares means no filtering", func(t *testing.T) { + perms := Permissions{ + "*": PermissionReadWrite, + } + empty := views.SliceOfViews([]*Share{}) + filtered := FilterPermissionsByIdentity(perms, "joe@example.com", nil, empty) + if filtered.For("anything") != PermissionReadWrite { + t.Error("wildcard should pass through with empty shares") + } + }) + + t.Run("name(domain) share access", func(t *testing.T) { + domainShares := views.SliceOfViews([]*Share{ + {Name: "alice(contractor)+bob"}, + {Name: "docs"}, + }) + perms := Permissions{ + "*": PermissionReadWrite, + } + // alice@contractor.io should access alice(contractor)+bob + filtered := FilterPermissionsByIdentity(perms, "alice@contractor.io", nil, domainShares) + if filtered.For("alice(contractor)+bob") != PermissionReadWrite { + t.Error("alice@contractor.io should access alice(contractor)+bob") + } + // alice@example.com should NOT access alice(contractor)+bob + filtered = FilterPermissionsByIdentity(perms, "alice@example.com", nil, domainShares) + if filtered.For("alice(contractor)+bob") != PermissionNone { + t.Error("alice@example.com should not access alice(contractor)+bob") + } + // bob@example.com should access alice(contractor)+bob + filtered = FilterPermissionsByIdentity(perms, "bob@example.com", nil, domainShares) + if filtered.For("alice(contractor)+bob") != PermissionReadWrite { + t.Error("bob@example.com should access alice(contractor)+bob") + } + }) +} diff --git a/ipn/ipnlocal/peerapi_drive.go b/ipn/ipnlocal/peerapi_drive.go index d42843577..193106c30 100644 --- a/ipn/ipnlocal/peerapi_drive.go +++ b/ipn/ipnlocal/peerapi_drive.go @@ -53,6 +53,9 @@ func handleServeDrive(hi PeerAPIHandler, w http.ResponseWriter, r *http.Request) return } + shares := h.ps.b.DriveGetShares() + p = drive.FilterPermissionsByIdentity(p, h.peerUser.LoginName, h.peerUser.Groups, shares) + fs, ok := h.ps.b.sys.DriveForRemote.GetOK() if !ok { h.logf("taildrive: not supported on platform")