diff --git a/dhcp6/packet_builder.go b/dhcp6/packet_builder.go index 987e38c..20c3427 100644 --- a/dhcp6/packet_builder.go +++ b/dhcp6/packet_builder.go @@ -6,53 +6,56 @@ import ( "net" ) +// PacketBuilder is used for generating responses to requests received from dhcp clients type PacketBuilder struct { PreferredLifetime uint32 ValidLifetime uint32 } +// MakePacketBuilder creates a new PacketBuilder and initializes it with preferred and valid lifetimes func MakePacketBuilder(preferredLifetime, validLifetime uint32) *PacketBuilder { return &PacketBuilder{PreferredLifetime: preferredLifetime, ValidLifetime: validLifetime} } +// BuildResponse generates a response packet for a packet received from a client func (b *PacketBuilder) BuildResponse(in *Packet, serverDUID []byte, configuration BootConfiguration, addresses AddressPool) (*Packet, error) { switch in.Type { case MsgSolicit: - bootFileURL, err := configuration.GetBootURL(b.ExtractLLAddressOrID(in.Options.ClientID()), in.Options.ClientArchType()) + bootFileURL, err := configuration.GetBootURL(b.extractLLAddressOrID(in.Options.ClientID()), in.Options.ClientArchType()) if err != nil { return nil, err } associations, err := addresses.ReserveAddresses(in.Options.ClientID(), in.Options.IaNaIDs()) if err != nil { - return b.MakeMsgAdvertiseWithNoAddrsAvailable(in.TransactionID, serverDUID, in.Options.ClientID(), err), err + return b.makeMsgAdvertiseWithNoAddrsAvailable(in.TransactionID, serverDUID, in.Options.ClientID(), err), err } - return b.MakeMsgAdvertise(in.TransactionID, serverDUID, in.Options.ClientID(), + return b.makeMsgAdvertise(in.TransactionID, serverDUID, in.Options.ClientID(), in.Options.ClientArchType(), associations, bootFileURL, configuration.GetPreference(), configuration.GetRecursiveDNS()), nil case MsgRequest: - bootFileURL, err := configuration.GetBootURL(b.ExtractLLAddressOrID(in.Options.ClientID()), in.Options.ClientArchType()) + bootFileURL, err := configuration.GetBootURL(b.extractLLAddressOrID(in.Options.ClientID()), in.Options.ClientArchType()) if err != nil { return nil, err } associations, err := addresses.ReserveAddresses(in.Options.ClientID(), in.Options.IaNaIDs()) - return b.MakeMsgReply(in.TransactionID, serverDUID, in.Options.ClientID(), + return b.makeMsgReply(in.TransactionID, serverDUID, in.Options.ClientID(), in.Options.ClientArchType(), associations, iasWithoutAddesses(associations, in.Options.IaNaIDs()), bootFileURL, configuration.GetRecursiveDNS(), err), err case MsgInformationRequest: - bootFileURL, err := configuration.GetBootURL(b.ExtractLLAddressOrID(in.Options.ClientID()), in.Options.ClientArchType()) + bootFileURL, err := configuration.GetBootURL(b.extractLLAddressOrID(in.Options.ClientID()), in.Options.ClientArchType()) if err != nil { return nil, err } - return b.MakeMsgInformationRequestReply(in.TransactionID, serverDUID, in.Options.ClientID(), + return b.makeMsgInformationRequestReply(in.TransactionID, serverDUID, in.Options.ClientID(), in.Options.ClientArchType(), bootFileURL, configuration.GetRecursiveDNS()), nil case MsgRelease: addresses.ReleaseAddresses(in.Options.ClientID(), in.Options.IaNaIDs()) - return b.MakeMsgReleaseReply(in.TransactionID, serverDUID, in.Options.ClientID()), nil + return b.makeMsgReleaseReply(in.TransactionID, serverDUID, in.Options.ClientID()), nil default: return nil, nil } } -func (b *PacketBuilder) MakeMsgAdvertise(transactionID [3]byte, serverDUID, clientID []byte, clientArchType uint16, +func (b *PacketBuilder) makeMsgAdvertise(transactionID [3]byte, serverDUID, clientID []byte, clientArchType uint16, associations []*IdentityAssociation, bootFileURL, preference []byte, dnsServers []net.IP) *Packet { retOptions := make(Options) retOptions.AddOption(MakeOption(OptClientID, clientID)) @@ -72,7 +75,7 @@ func (b *PacketBuilder) MakeMsgAdvertise(transactionID [3]byte, serverDUID, clie return &Packet{Type: MsgAdvertise, TransactionID: transactionID, Options: retOptions} } -func (b *PacketBuilder) MakeMsgReply(transactionID [3]byte, serverDUID, clientID []byte, clientArchType uint16, +func (b *PacketBuilder) makeMsgReply(transactionID [3]byte, serverDUID, clientID []byte, clientArchType uint16, associations []*IdentityAssociation, iasWithoutAddresses [][]byte, bootFileURL []byte, dnsServers []net.IP, err error) *Packet { retOptions := make(Options) retOptions.AddOption(MakeOption(OptClientID, clientID)) @@ -94,7 +97,7 @@ func (b *PacketBuilder) MakeMsgReply(transactionID [3]byte, serverDUID, clientID return &Packet{Type: MsgReply, TransactionID: transactionID, Options: retOptions} } -func (b *PacketBuilder) MakeMsgInformationRequestReply(transactionID [3]byte, serverDUID, clientID []byte, clientArchType uint16, +func (b *PacketBuilder) makeMsgInformationRequestReply(transactionID [3]byte, serverDUID, clientID []byte, clientArchType uint16, bootFileURL []byte, dnsServers []net.IP) *Packet { retOptions := make(Options) retOptions.AddOption(MakeOption(OptClientID, clientID)) @@ -108,7 +111,7 @@ func (b *PacketBuilder) MakeMsgInformationRequestReply(transactionID [3]byte, se return &Packet{Type: MsgReply, TransactionID: transactionID, Options: retOptions} } -func (b *PacketBuilder) MakeMsgReleaseReply(transactionID [3]byte, serverDUID, clientID []byte) *Packet { +func (b *PacketBuilder) makeMsgReleaseReply(transactionID [3]byte, serverDUID, clientID []byte) *Packet { retOptions := make(Options) retOptions.AddOption(MakeOption(OptClientID, clientID)) @@ -120,7 +123,7 @@ func (b *PacketBuilder) MakeMsgReleaseReply(transactionID [3]byte, serverDUID, c return &Packet{Type: MsgReply, TransactionID: transactionID, Options: retOptions} } -func (b *PacketBuilder) MakeMsgAdvertiseWithNoAddrsAvailable(transactionID [3]byte, serverDUID, clientID []byte, err error) *Packet { +func (b *PacketBuilder) makeMsgAdvertiseWithNoAddrsAvailable(transactionID [3]byte, serverDUID, clientID []byte, err error) *Packet { retOptions := make(Options) retOptions.AddOption(MakeOption(OptClientID, clientID)) retOptions.AddOption(MakeOption(OptServerID, serverDUID)) @@ -136,7 +139,7 @@ func (b *PacketBuilder) calculateT2() uint32 { return (b.PreferredLifetime * 4)/5 } -func (b *PacketBuilder) ExtractLLAddressOrID(optClientID []byte) []byte { +func (b *PacketBuilder) extractLLAddressOrID(optClientID []byte) []byte { idType := binary.BigEndian.Uint16(optClientID[0:2]) switch idType { case 1: diff --git a/dhcp6/packet_builder_test.go b/dhcp6/packet_builder_test.go index f30e513..4a149c8 100644 --- a/dhcp6/packet_builder_test.go +++ b/dhcp6/packet_builder_test.go @@ -19,7 +19,7 @@ func TestMakeMsgAdvertise(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgAdvertise(transactionID, expectedServerID, expectedClientID, 0x11, + msg := builder.makeMsgAdvertise(transactionID, expectedServerID, expectedClientID, 0x11, []*IdentityAssociation{identityAssociation}, expectedBootFileURL, nil, []net.IP{expectedDNSServerIP}) if msg.Type != MsgAdvertise { @@ -83,7 +83,7 @@ func TestMakeMsgAdvertiseShouldSkipDnsServersIfNoneConfigured(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgAdvertise(transactionID, expectedServerID, expectedClientID, 0x11, + msg := builder.makeMsgAdvertise(transactionID, expectedServerID, expectedClientID, 0x11, []*IdentityAssociation{identityAssociation}, expectedBootFileURL, nil, []net.IP{}) _, exists := msg.Options[OptRecursiveDNS]; if exists { @@ -97,7 +97,7 @@ func TestShouldSetPreferenceOptionWhenSpecified(t *testing.T) { builder := MakePacketBuilder(90, 100) expectedPreference := []byte{128} - msg := builder.MakeMsgAdvertise([3]byte{'t', 'i', 'd'}, []byte("serverid"), []byte("clientid"), 0x11, + msg := builder.makeMsgAdvertise([3]byte{'t', 'i', 'd'}, []byte("serverid"), []byte("clientid"), 0x11, []*IdentityAssociation{identityAssociation}, []byte("http://bootfileurl"), expectedPreference, []net.IP{}) preferenceOption := msg.Options[OptPreference] @@ -119,7 +119,7 @@ func TestMakeMsgAdvertiseWithHttpClientArch(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgAdvertise(transactionID, expectedServerID, expectedClientID, 0x10, + msg := builder.makeMsgAdvertise(transactionID, expectedServerID, expectedClientID, 0x10, []*IdentityAssociation{identityAssociation}, expectedBootFileURL, nil, []net.IP{}) vendorClassOption := msg.Options[OptVendorClass] @@ -143,7 +143,7 @@ func TestMakeNoAddrsAvailable(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgAdvertiseWithNoAddrsAvailable(transactionID, expectedServerID, expectedClientID, fmt.Errorf(expectedMessage)) + msg := builder.makeMsgAdvertiseWithNoAddrsAvailable(transactionID, expectedServerID, expectedClientID, fmt.Errorf(expectedMessage)) if msg.Type != MsgAdvertise { t.Fatalf("Expected message type %d, got %d", MsgAdvertise, msg.Type) @@ -191,7 +191,7 @@ func TestMakeMsgReply(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgReply(transactionID, expectedServerID, expectedClientID, 0x11, + msg := builder.makeMsgReply(transactionID, expectedServerID, expectedClientID, 0x11, []*IdentityAssociation{identityAssociation}, make([][]byte, 0), expectedBootFileURL, []net.IP{expectedDNSServerIP}, nil) if msg.Type != MsgReply { @@ -255,7 +255,7 @@ func TestMakeMsgReplyShouldSkipDnsServersIfNoneWereConfigured(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgReply(transactionID, expectedServerID, expectedClientID, 0x11, + msg := builder.makeMsgReply(transactionID, expectedServerID, expectedClientID, 0x11, []*IdentityAssociation{identityAssociation}, make([][]byte, 0), expectedBootFileURL, []net.IP{}, nil) _, exists := msg.Options[OptRecursiveDNS]; if exists { @@ -273,7 +273,7 @@ func TestMakeMsgReplyWithHttpClientArch(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgReply(transactionID, expectedServerID, expectedClientID, 0x10, + msg := builder.makeMsgReply(transactionID, expectedServerID, expectedClientID, 0x10, []*IdentityAssociation{identityAssociation}, make([][]byte, 0), expectedBootFileURL, []net.IP{}, nil) vendorClassOption := msg.Options[OptVendorClass] @@ -301,7 +301,7 @@ func TestMakeMsgReplyWithNoAddrsAvailable(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgReply(transactionID, expectedServerID, expectedClientID, 0x10, + msg := builder.makeMsgReply(transactionID, expectedServerID, expectedClientID, 0x10, []*IdentityAssociation{identityAssociation}, [][]byte{[]byte("id-2")}, expectedBootFileURL, []net.IP{}, fmt.Errorf(expectedErrorMessage)) @@ -353,7 +353,7 @@ func TestMakeMsgInformationRequestReply(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgInformationRequestReply(transactionID, expectedServerID, expectedClientID, 0x11, + msg := builder.makeMsgInformationRequestReply(transactionID, expectedServerID, expectedClientID, 0x11, expectedBootFileURL, []net.IP{expectedDNSServerIP}) if msg.Type != MsgReply { @@ -410,7 +410,7 @@ func TestMakeMsgInformationRequestReplyShouldSkipDnsServersIfNoneWereConfigured( builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgInformationRequestReply(transactionID, expectedServerID, expectedClientID, 0x11, + msg := builder.makeMsgInformationRequestReply(transactionID, expectedServerID, expectedClientID, 0x11, expectedBootFileURL, []net.IP{}) _, exists := msg.Options[OptRecursiveDNS]; if exists { @@ -426,7 +426,7 @@ func TestMakeMsgInformationRequestReplyWithHttpClientArch(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgInformationRequestReply(transactionID, expectedServerID, expectedClientID, 0x10, + msg := builder.makeMsgInformationRequestReply(transactionID, expectedServerID, expectedClientID, 0x10, expectedBootFileURL, []net.IP{}) vendorClassOption := msg.Options[OptVendorClass] @@ -450,7 +450,7 @@ func TestMakeMsgReleaseReply(t *testing.T) { builder := MakePacketBuilder(90, 100) - msg := builder.MakeMsgReleaseReply(transactionID, expectedServerID, expectedClientID) + msg := builder.makeMsgReleaseReply(transactionID, expectedServerID, expectedClientID) if msg.Type != MsgReply { t.Fatalf("Expected message type %d, got %d", MsgAdvertise, msg.Type) @@ -485,7 +485,7 @@ func TestMakeMsgReleaseReply(t *testing.T) { func TestExtractLLAddressOrIdWithDUIDLLT(t *testing.T) { builder := &PacketBuilder{} expectedLLAddress := []byte{0xac, 0xbc, 0x32, 0xae, 0x86, 0x37} - llAddress := builder.ExtractLLAddressOrID([]byte{0x0, 0x1, 0x0, 0x1, 0x1, 0x2, 0x3, 0x4, 0xac, 0xbc, 0x32, 0xae, 0x86, 0x37}) + llAddress := builder.extractLLAddressOrID([]byte{0x0, 0x1, 0x0, 0x1, 0x1, 0x2, 0x3, 0x4, 0xac, 0xbc, 0x32, 0xae, 0x86, 0x37}) if string(expectedLLAddress) != string(llAddress) { t.Fatalf("Expected ll address %x, got: %x", expectedLLAddress, llAddress) } @@ -494,7 +494,7 @@ func TestExtractLLAddressOrIdWithDUIDLLT(t *testing.T) { func TestExtractLLAddressOrIdWithDUIDEN(t *testing.T) { builder := &PacketBuilder{} expectedID := []byte{0x0, 0x1, 0x2, 0x3, 0xac, 0xbc, 0x32, 0xae, 0x86, 0x37} - id := builder.ExtractLLAddressOrID([]byte{0x0, 0x2, 0x0, 0x1, 0x2, 0x3, 0xac, 0xbc, 0x32, 0xae, 0x86, 0x37}) + id := builder.extractLLAddressOrID([]byte{0x0, 0x2, 0x0, 0x1, 0x2, 0x3, 0xac, 0xbc, 0x32, 0xae, 0x86, 0x37}) if string(expectedID) != string(id) { t.Fatalf("Expected id %x, got: %x", expectedID, id) } @@ -503,7 +503,7 @@ func TestExtractLLAddressOrIdWithDUIDEN(t *testing.T) { func TestExtractLLAddressOrIdWithDUIDLL(t *testing.T) { builder := &PacketBuilder{} expectedLLAddress := []byte{0xac, 0xbc, 0x32, 0xae, 0x86, 0x37} - llAddress := builder.ExtractLLAddressOrID([]byte{0x0, 0x3, 0x0, 0x1, 0xac, 0xbc, 0x32, 0xae, 0x86, 0x37}) + llAddress := builder.extractLLAddressOrID([]byte{0x0, 0x3, 0x0, 0x1, 0xac, 0xbc, 0x32, 0xae, 0x86, 0x37}) if string(expectedLLAddress) != string(llAddress) { t.Fatalf("Expected ll address %x, got: %x", expectedLLAddress, llAddress) } diff --git a/dhcp6/random_address_pool.go b/dhcp6/random_address_pool.go index fbeb208..8505feb 100644 --- a/dhcp6/random_address_pool.go +++ b/dhcp6/random_address_pool.go @@ -10,49 +10,52 @@ import ( "fmt" ) -type AssociationExpiration struct { +type associationExpiration struct { expiresAt time.Time ia *IdentityAssociation } -type Fifo struct {q []interface{}} +type fifo struct {q []interface{}} -func newFifo() Fifo { - return Fifo{q: make([]interface{}, 0, 1000)} +func newFifo() fifo { + return fifo{q: make([]interface{}, 0, 1000)} } -func (f *Fifo) Push(v interface{}) { +func (f *fifo) Push(v interface{}) { f.q = append(f.q, v) } -func (f *Fifo) Shift() interface{} { +func (f *fifo) Shift() interface{} { var ret interface{} ret, f.q = f.q[0], f.q[1:] return ret } -func (f *Fifo) Size() int { +func (f *fifo) Size() int { return len(f.q) } -func (f *Fifo) Peek() interface{} { +func (f *fifo) Peek() interface{} { if len(f.q) == 0 { return nil } return f.q[0] } +// RandomAddressPool that returns a random IP address from a pool of available addresses type RandomAddressPool struct { poolStartAddress *big.Int - poolSize uint64 + poolSize uint64 identityAssociations map[uint64]*IdentityAssociation usedIps map[uint64]struct{} - identityAssociationExpirations Fifo + identityAssociationExpirations fifo validLifetime uint32 // in seconds timeNow func() time.Time lock sync.Mutex } +// NewRandomAddressPool creates a new RandomAddressPool using pool start IP address, pool size, and valid lifetime of +// interface associations func NewRandomAddressPool(poolStartAddress net.IP, poolSize uint64, validLifetime uint32) *RandomAddressPool { ret := &RandomAddressPool{} ret.validLifetime = validLifetime @@ -75,6 +78,7 @@ func NewRandomAddressPool(poolStartAddress net.IP, poolSize uint64, validLifetim return ret } +// ReserveAddresses creates new or retrieves active associations for interfaces in interfaceIDs list. func (p *RandomAddressPool) ReserveAddresses(clientID []byte, interfaceIDs [][]byte) ([]*IdentityAssociation, error) { p.lock.Lock() defer p.lock.Unlock() @@ -107,7 +111,7 @@ func (p *RandomAddressPool) ReserveAddresses(clientID []byte, interfaceIDs [][]b CreatedAt: timeNow} p.identityAssociations[clientIDHash] = association p.usedIps[newIP.Uint64()] = struct{}{} - p.identityAssociationExpirations.Push(&AssociationExpiration{expiresAt: p.calculateAssociationExpiration(timeNow), ia: association}) + p.identityAssociationExpirations.Push(&associationExpiration{expiresAt: p.calculateAssociationExpiration(timeNow), ia: association}) ret = append(ret, association) break } @@ -117,6 +121,7 @@ func (p *RandomAddressPool) ReserveAddresses(clientID []byte, interfaceIDs [][]b return ret, nil } +// ReleaseAddresses returns IP addresses associated with ClientID and interfaceIDs back into the address pool func (p *RandomAddressPool) ReleaseAddresses(clientID []byte, interfaceIDs [][]byte) { p.lock.Lock() defer p.lock.Unlock() @@ -131,13 +136,15 @@ func (p *RandomAddressPool) ReleaseAddresses(clientID []byte, interfaceIDs [][] } } +// ExpireIdentityAssociations releases IP addresses in identity associations that reached the end of valid lifetime +// back into the address pool func (p *RandomAddressPool) ExpireIdentityAssociations() { p.lock.Lock() defer p.lock.Unlock() for { if p.identityAssociationExpirations.Size() < 1 { break } - expiration := p.identityAssociationExpirations.Peek().(*AssociationExpiration) + expiration := p.identityAssociationExpirations.Peek().(*associationExpiration) if p.timeNow().Before(expiration.expiresAt) { break } p.identityAssociationExpirations.Shift() delete(p.identityAssociations, p.calculateIAIDHash(expiration.ia.ClientID, expiration.ia.InterfaceID)) diff --git a/dhcp6/random_address_pool_test.go b/dhcp6/random_address_pool_test.go index ec25a70..4b6035a 100644 --- a/dhcp6/random_address_pool_test.go +++ b/dhcp6/random_address_pool_test.go @@ -95,7 +95,7 @@ func TestReserveAddressKeepsTrackOfAssociationExpiration(t *testing.T) { pool.timeNow = func() time.Time { return expectedTime } pool.ReserveAddresses(expectedClientID, [][]byte{expectedIAID}) - expiration := pool.identityAssociationExpirations.Peek().(*AssociationExpiration) + expiration := pool.identityAssociationExpirations.Peek().(*associationExpiration) if expiration == nil { t.Fatal("Expected an identity association expiration, but got nil") }