diff --git a/dhcp6/address_pool.go b/dhcp6/address_pool.go new file mode 100644 index 0000000..7328753 --- /dev/null +++ b/dhcp6/address_pool.go @@ -0,0 +1,33 @@ +package dhcp6 + +import ( + "net" + "time" +) + +type IdentityAssociation struct { + ipAddress net.IP + clientId []byte + interfaceId []byte + createdAt time.Time + t1 uint32 + t2 uint32 +} + +type AddressPool interface { + ReserveAddress(clientId, interfaceId []byte) *IdentityAssociation + ReleaseAddress(clientId, interfaceId []byte, addr net.IP) +} + + + + + + + + + + + + + diff --git a/dhcp6/conn.go b/dhcp6/conn.go index 6f4ab9e..c4ee009 100644 --- a/dhcp6/conn.go +++ b/dhcp6/conn.go @@ -70,7 +70,7 @@ func InterfaceIndexByAddress(ifAddr string) (*net.Interface, error) { return nil, fmt.Errorf("Error getting network interface address information: %s", err) } for _, addr := range addrs { - if addr.String() == ifAddr { + if addrToIP(addr).String() == ifAddr { return &ifi, nil } } @@ -78,10 +78,22 @@ func InterfaceIndexByAddress(ifAddr string) (*net.Interface, error) { return nil, fmt.Errorf("Couldn't find an interface with address %s", ifAddr) } +func addrToIP(a net.Addr) net.IP { + var ip net.IP + switch v := a.(type) { + case *net.IPAddr: + ip = v.IP + case *net.IPNet: + ip = v.IP + } + + return ip +} + func (c *Conn) RecvDHCP() (*Packet, net.IP, error) { b := make([]byte, 1500) for { - _, rcm, _, err := c.conn.ReadFrom(b) + n, rcm, _, err := c.conn.ReadFrom(b) if err != nil { return nil, nil, err } @@ -91,7 +103,7 @@ func (c *Conn) RecvDHCP() (*Packet, net.IP, error) { if !rcm.Dst.IsMulticast() || !rcm.Dst.Equal(c.group) { continue // unknown group, discard } - pkt, err := MakePacket(b) + pkt, err := MakePacket(b, n) if err != nil { return nil, nil, err } diff --git a/dhcp6/options.go b/dhcp6/options.go index 94305fa..2bd891a 100644 --- a/dhcp6/options.go +++ b/dhcp6/options.go @@ -40,6 +40,10 @@ type Option struct { Value []byte } +func MakeOption(id uint16, value []byte) *Option { + return &Option{ Id: id, Length: uint16(len(value)), Value: value} +} + type Options map[uint16]*Option func MakeOptions(bs []byte) (Options, error) { @@ -71,11 +75,48 @@ func MakeOptions(bs []byte) (Options, error) { func (o Options) HumanReadable() []string { to_ret := make([]string, 0, len(o)) for _, opt := range(o) { - to_ret = append(to_ret, fmt.Sprintf("Option: %d | %d | %d | %s\n", opt.Id, opt.Length, opt.Value, opt.Value)) + switch opt.Id { + case 3: + to_ret = append(to_ret, o.HumanReadableIaNa(*opt)...) + default: + to_ret = append(to_ret, fmt.Sprintf("Option: %d | %d | %d | %s\n", opt.Id, opt.Length, opt.Value, opt.Value)) + } } return to_ret } +func (o Options) HumanReadableIaNa(opt Option) []string { + to_ret := make([]string, 0) + to_ret = append(to_ret, fmt.Sprintf("Option: OptIaNa | len %d | iaid %x | t1 %d | t2 %d\n", + opt.Length, opt.Value[0:4], binary.BigEndian.Uint32(opt.Value[4:8]), binary.BigEndian.Uint32(opt.Value[8:12]))) + + if opt.Length <= 12 { + return to_ret // no options + } + + iaOptions := opt.Value[12:] + for len(iaOptions) > 0 { + l := uint16(binary.BigEndian.Uint16(iaOptions[2:4])) + id := uint16(binary.BigEndian.Uint16(iaOptions[0:2])) + + + switch id { + case OptIaAddr: + ip := make(net.IP, 16) + copy(ip, iaOptions[4:20]) + to_ret = append(to_ret, fmt.Sprintf("\tOption: IA_ADDR | len %d | ip %s | preferred %d | valid %d | %v \n", + l, ip, binary.BigEndian.Uint32(iaOptions[20:24]), binary.BigEndian.Uint32(iaOptions[24:28]), iaOptions[28:4+l])) + default: + to_ret = append(to_ret, fmt.Sprintf("\tOption: id %d | len %d | %s\n", + id, l, iaOptions[4:4+l])) + } + + iaOptions = iaOptions[4+l:] + } + + return to_ret +} + func (o Options) AddOption(option *Option) { o[option.Id] = option } @@ -87,7 +128,7 @@ func MakeIaNaOption(iaid []byte, t1, t2 uint32, iaAddr *Option) (*Option) { binary.BigEndian.PutUint32(value[4:], t1) binary.BigEndian.PutUint32(value[8:], t2) copy(value[12:], serializedIaAddr) - return &Option{Id: OptIaNa, Length: uint16(len(value)), Value: value} + return MakeOption(OptIaNa, value) } func MakeIaAddrOption(addr net.IP, preferredLifetime, validLifetime uint32) (*Option) { @@ -95,7 +136,7 @@ func MakeIaAddrOption(addr net.IP, preferredLifetime, validLifetime uint32) (*Op copy(value[0:], addr) binary.BigEndian.PutUint32(value[16:], preferredLifetime) binary.BigEndian.PutUint32(value[20:], validLifetime) - return &Option{ Id: OptIaAddr, Length: uint16(len(value)), Value: value} + return MakeOption(OptIaAddr, value) } func (o Options) Marshal() ([]byte, error) { diff --git a/dhcp6/packet.go b/dhcp6/packet.go index 6f82d10..0a7a3c8 100644 --- a/dhcp6/packet.go +++ b/dhcp6/packet.go @@ -2,7 +2,6 @@ package dhcp6 import ( "fmt" - "net" "encoding/binary" "bytes" ) @@ -31,8 +30,8 @@ type Packet struct { Options Options } -func MakePacket(bs []byte) (*Packet, error) { - options, err := MakeOptions(bs[4:]) // 4:len? +func MakePacket(bs []byte, packetLength int) (*Packet, error) { + options, err := MakeOptions(bs[4:packetLength]) if err != nil { return nil, fmt.Errorf("packet has malformed options section: %s", err) } @@ -48,24 +47,26 @@ func (p *Packet) Marshal() ([]byte, error) { } ret := make([]byte, len(marshalled_options) + 4, len(marshalled_options) + 4) - ret[0] = byte(MsgAdvertise) + ret[0] = byte(p.Type) copy(ret[1:], p.TransactionID[:]) copy(ret[4:], marshalled_options) return ret, nil } -func (p *Packet) BuildResponse(serverDuid []byte) *Packet { +func (p *Packet) BuildResponse(serverDuid []byte, addressPool AddressPool) *Packet { transactionId := p.TransactionID clientId := p.Options[OptClientId].Value iaNaId := p.Options[OptIaNa].Value[0:4] - clientArchType := p.Options[OptClientArchType].Value - + var clientArchType []byte + o, exists := p.Options[OptClientArchType]; if exists { + clientArchType = o.Value + } switch p.Type { case MsgSolicit: - return MakeMsgAdvertise(transactionId, serverDuid, clientId, iaNaId, clientArchType) + return MakeMsgAdvertise(transactionId, serverDuid, clientId, iaNaId, clientArchType, addressPool) case MsgRequest: - return MakeMsgReply(transactionId, serverDuid, clientId, iaNaId, clientArchType) + return MakeMsgReply(transactionId, serverDuid, clientId, iaNaId, clientArchType, addressPool) case MsgInformationRequest: return MakeMsgInformationRequestReply(transactionId, serverDuid, clientId, clientArchType) case MsgRelease: @@ -75,19 +76,19 @@ func (p *Packet) BuildResponse(serverDuid []byte) *Packet { } } -func MakeMsgAdvertise(transactionId [3]byte, serverDuid, clientId, iaId, clientArchType []byte) *Packet { +func MakeMsgAdvertise(transactionId [3]byte, serverDuid, clientId, iaId, clientArchType []byte, addressPool AddressPool) *Packet { ret_options := make(Options) - - ret_options.AddOption(&Option{Id: OptClientId, Length: uint16(len(clientId)), Value: clientId}) - ret_options.AddOption(MakeIaNaOption(iaId, 0, 0, - MakeIaAddrOption(net.ParseIP("2001:db8:f00f:cafe::99"), 27000, 43200))) - ret_options.AddOption(&Option{Id: OptServerId, Length: uint16(len(serverDuid)), Value: serverDuid}) + ret_options.AddOption(MakeOption(OptClientId, clientId)) + association := addressPool.ReserveAddress(clientId, iaId) + ret_options.AddOption(MakeIaNaOption(iaId, association.t1, association.t2, + MakeIaAddrOption(association.ipAddress, 27000, 43200))) + ret_options.AddOption(MakeOption(OptServerId, serverDuid)) if 0x10 == binary.BigEndian.Uint16(clientArchType) { // HTTPClient - ret_options.AddOption(&Option{Id: OptVendorClass, Length: 16, Value: []byte {0, 0, 0, 0, 0, 10, 72, 84, 84, 80, 67, 108, 105, 101, 110, 116}}) // HTTPClient - ret_options.AddOption(&Option{Id: OptBootfileUrl, Length: 42, Value: []byte("http://[2001:db8:f00f:cafe::4]/bootx64.efi")}) + ret_options.AddOption(MakeOption(OptVendorClass, []byte {0, 0, 0, 0, 0, 10, 72, 84, 84, 80, 67, 108, 105, 101, 110, 116})) // HTTPClient + ret_options.AddOption(MakeOption(OptBootfileUrl, []byte("http://[2001:db8:f00f:cafe::4]/bootx64.efi"))) } else { - ret_options.AddOption(&Option{Id: OptBootfileUrl, Length: 42, Value: []byte("http://[2001:db8:f00f:cafe::4]/script.ipxe")}) + ret_options.AddOption(MakeOption(OptBootfileUrl, []byte("http://[2001:db8:f00f:cafe::4]/script.ipxe"))) } // ret_options.AddOption(OptRecursiveDns, net.ParseIP("2001:db8:f00f:cafe::1")) //ret_options.AddOption(OptBootfileParam, []byte("http://") @@ -98,19 +99,20 @@ func MakeMsgAdvertise(transactionId [3]byte, serverDuid, clientId, iaId, clientA // TODO: OptClientArchType may not be present -func MakeMsgReply(transactionId [3]byte, serverDuid, clientId, iaId, clientArchType []byte) *Packet { +func MakeMsgReply(transactionId [3]byte, serverDuid, clientId, iaId, clientArchType []byte, addressPool AddressPool) *Packet { ret_options := make(Options) - ret_options.AddOption(&Option{Id: OptClientId, Length: uint16(len(clientId)), Value: clientId}) - ret_options.AddOption(MakeIaNaOption(iaId, 0, 0, - MakeIaAddrOption(net.ParseIP("2001:db8:f00f:cafe::99"), 27000, 43200))) - ret_options.AddOption(&Option{Id: OptServerId, Length: uint16(len(serverDuid)), Value: serverDuid}) + ret_options.AddOption(MakeOption(OptClientId, clientId)) + association := addressPool.ReserveAddress(clientId, iaId) + ret_options.AddOption(MakeIaNaOption(iaId, association.t1, association.t2, + MakeIaAddrOption(association.ipAddress, 27000, 43200))) + ret_options.AddOption(MakeOption(OptServerId, serverDuid)) // ret_options.AddOption(OptRecursiveDns, net.ParseIP("2001:db8:f00f:cafe::1")) if 0x10 == binary.BigEndian.Uint16(clientArchType) { // HTTPClient - ret_options.AddOption(&Option{Id: OptVendorClass, Length: 16, Value: []byte {0, 0, 0, 0, 0, 10, 72, 84, 84, 80, 67, 108, 105, 101, 110, 116}}) // HTTPClient - ret_options.AddOption(&Option{Id: OptBootfileUrl, Length: 42, Value: []byte("http://[2001:db8:f00f:cafe::4]/bootx64.efi")}) + ret_options.AddOption(MakeOption(OptVendorClass, []byte {0, 0, 0, 0, 0, 10, 72, 84, 84, 80, 67, 108, 105, 101, 110, 116})) // HTTPClient + ret_options.AddOption(MakeOption(OptBootfileUrl, []byte("http://[2001:db8:f00f:cafe::4]/bootx64.efi"))) } else { - ret_options.AddOption(&Option{Id: OptBootfileUrl, Length: 42, Value: []byte("http://[2001:db8:f00f:cafe::4]/script.ipxe")}) + ret_options.AddOption(MakeOption(OptBootfileUrl, []byte("http://[2001:db8:f00f:cafe::4]/script.ipxe"))) } return &Packet{Type: MsgReply, TransactionID: transactionId, Options: ret_options} @@ -118,14 +120,14 @@ func MakeMsgReply(transactionId [3]byte, serverDuid, clientId, iaId, clientArchT func MakeMsgInformationRequestReply(transactionId [3]byte, serverDuid, clientId, clientArchType []byte) *Packet { ret_options := make(Options) - ret_options.AddOption(&Option{Id: OptClientId, Length: uint16(len(clientId)), Value: clientId}) - ret_options.AddOption(&Option{Id: OptServerId, Length: uint16(len(serverDuid)), Value: serverDuid}) + ret_options.AddOption(MakeOption(OptClientId, clientId)) + ret_options.AddOption(MakeOption(OptServerId, serverDuid)) // ret_options.AddOption(OptRecursiveDns, net.ParseIP("2001:db8:f00f:cafe::1")) if 0x10 == binary.BigEndian.Uint16(clientArchType) { // HTTPClient - ret_options.AddOption(&Option{Id: OptVendorClass, Length: 16, Value: []byte{0, 0, 0, 0, 0, 10, 72, 84, 84, 80, 67, 108, 105, 101, 110, 116}}) // HTTPClient - ret_options.AddOption(&Option{Id: OptBootfileUrl, Length: 42, Value: []byte("http://[2001:db8:f00f:cafe::4]/bootx64.efi")}) + ret_options.AddOption(MakeOption(OptVendorClass, []byte{0, 0, 0, 0, 0, 10, 72, 84, 84, 80, 67, 108, 105, 101, 110, 116})) // HTTPClient + ret_options.AddOption(MakeOption(OptBootfileUrl, []byte("http://[2001:db8:f00f:cafe::4]/bootx64.efi"))) } else { - ret_options.AddOption(&Option{Id: OptBootfileUrl, Length: 42, Value: []byte("http://[2001:db8:f00f:cafe::4]/script.ipxe")}) + ret_options.AddOption(MakeOption(OptBootfileUrl, []byte("http://[2001:db8:f00f:cafe::4]/script.ipxe"))) } return &Packet{Type: MsgReply, TransactionID: transactionId, Options: ret_options} @@ -134,11 +136,11 @@ func MakeMsgInformationRequestReply(transactionId [3]byte, serverDuid, clientId, func MakeMsgReleaseReply(transactionId [3]byte, serverDuid, clientId []byte) *Packet { ret_options := make(Options) - ret_options.AddOption(&Option{Id: OptClientId, Length: uint16(len(clientId)), Value: clientId}) - ret_options.AddOption(&Option{Id: OptServerId, Length: uint16(len(serverDuid)), Value: serverDuid}) + ret_options.AddOption(MakeOption(OptClientId, clientId)) + ret_options.AddOption(MakeOption(OptServerId, serverDuid)) v := make([]byte, 19, 19) copy(v[2:], []byte("Release received.")) - ret_options.AddOption(&Option{Id: OptStatusCode, Length: uint16(len(v)), Value: v}) + ret_options.AddOption(MakeOption(OptStatusCode, v)) return &Packet{Type: MsgReply, TransactionID: transactionId, Options: ret_options} } diff --git a/dhcp6/packet_test.go b/dhcp6/packet_test.go index f7302d4..7885aa2 100644 --- a/dhcp6/packet_test.go +++ b/dhcp6/packet_test.go @@ -3,6 +3,7 @@ package dhcp6 import ( "testing" "encoding/binary" + "net" ) func TestMakeMsgAdvertise(t *testing.T) { @@ -10,7 +11,8 @@ func TestMakeMsgAdvertise(t *testing.T) { expectedServerId := []byte("serverid") transactionId := [3]byte{'1', '2', '3'} - msg := MakeMsgAdvertise(transactionId, expectedServerId, expectedClientId, []byte("1234"), []byte("11")) + msg := MakeMsgAdvertise(transactionId, expectedServerId, expectedClientId, []byte("1234"), []byte("11"), + NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), 100)) if msg.Type != MsgAdvertise { t.Fatalf("Expected message type %d, got %d", MsgAdvertise, msg.Type) @@ -57,7 +59,8 @@ func TestMakeMsgAdvertiseWithHttpClientArch(t *testing.T) { expectedServerId := []byte("serverid") transactionId := [3]byte{'1', '2', '3'} - msg := MakeMsgAdvertise(transactionId, expectedServerId, expectedClientId, []byte("1234"), []byte{0x0, 0x10}) + msg := MakeMsgAdvertise(transactionId, expectedServerId, expectedClientId, []byte("1234"), []byte{0x0, 0x10}, + NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), 100)) vendorClassOption := msg.Options[OptVendorClass] if vendorClassOption == nil { @@ -75,7 +78,8 @@ func TestMakeMsgReply(t *testing.T) { expectedServerId := []byte("serverid") transactionId := [3]byte{'1', '2', '3'} - msg := MakeMsgReply(transactionId, expectedServerId, expectedClientId, []byte("1234"), []byte("11")) + msg := MakeMsgReply(transactionId, expectedServerId, expectedClientId, []byte("1234"), []byte("11"), + NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), 100)) if msg.Type != MsgReply { t.Fatalf("Expected message type %d, got %d", MsgAdvertise, msg.Type) @@ -122,7 +126,8 @@ func TestMakeMsgReplyWithHttpClientArch(t *testing.T) { expectedServerId := []byte("serverid") transactionId := [3]byte{'1', '2', '3'} - msg := MakeMsgReply(transactionId, expectedServerId, expectedClientId, []byte("1234"), []byte{0x0, 0x10}) + msg := MakeMsgReply(transactionId, expectedServerId, expectedClientId, []byte("1234"), []byte{0x0, 0x10}, + NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), 100)) vendorClassOption := msg.Options[OptVendorClass] if vendorClassOption == nil { diff --git a/dhcp6/random_address_pool.go b/dhcp6/random_address_pool.go new file mode 100644 index 0000000..db0e0b2 --- /dev/null +++ b/dhcp6/random_address_pool.go @@ -0,0 +1,148 @@ +package dhcp6 + +import ( + "net" + "math/rand" + "time" + "math/big" + "hash/fnv" +) + +type AssociationExpiration struct { + expiresAt time.Time + ia *IdentityAssociation +} + +type Fifo struct {q []interface{}} + +func newFifo() Fifo { + return Fifo{q: make([]interface{}, 0, 1000)} +} + +func (f *Fifo) Push(v interface{}) { + f.q = append(f.q, v) +} + +func (f *Fifo) Shift() interface{} { + var to_ret interface{} + to_ret, f.q = f.q[0], f.q[1:] + return to_ret +} + +func (f *Fifo) Size() int { + return len(f.q) +} + +func (f *Fifo) Peek() interface{} { + if len(f.q) == 0 { + return nil + } + return f.q[0] +} + +type RandomAddressPool struct { + poolStartAddress *big.Int + poolEndAddress *big.Int + identityAssociations map[uint64]*IdentityAssociation + usedIps map[uint64]struct{} + identityAssociationExpirations Fifo + preferredLifetime uint32 // in seconds + timeNow func() time.Time + lock chan int +} + +func NewRandomAddressPool(poolStartAddress, poolEndAddress net.IP, preferredLifetime uint32) *RandomAddressPool { + to_ret := &RandomAddressPool{} + to_ret.preferredLifetime = preferredLifetime + to_ret.poolStartAddress = big.NewInt(0) + to_ret.poolStartAddress.SetBytes(poolStartAddress) + to_ret.poolEndAddress = big.NewInt(0) + to_ret.poolEndAddress.SetBytes(poolEndAddress) + to_ret.identityAssociations = make(map[uint64]*IdentityAssociation) + to_ret.usedIps = make(map[uint64]struct{}) + to_ret.identityAssociationExpirations = newFifo() + to_ret.timeNow = func() time.Time { return time.Now() } + to_ret.lock = make(chan int, 1) + to_ret.lock <- 1 + + ticker := time.NewTicker(time.Second * 10).C + go func() { + for { + <- ticker + to_ret.ExpireIdentityAssociations() + } + }() + + return to_ret +} + +func (p *RandomAddressPool) ReserveAddress(clientId, interfaceId []byte) *IdentityAssociation { + <-p.lock + clientIdHash := p.calculateIaIdHash(clientId, interfaceId) + association, exists := p.identityAssociations[clientIdHash]; if exists { + p.lock <- 1 + return association + } + + for { + rng := rand.New(rand.NewSource(p.timeNow().UnixNano())) + // we assume that ip addresses adhere to high 64 bits for net and subnet ids, low 64 bits are for host id rule + hostOffset := rng.Uint64()%(p.poolEndAddress.Uint64() - p.poolStartAddress.Uint64() + 1) + newIp := big.NewInt(0).Add(p.poolStartAddress, big.NewInt(0).SetUint64(hostOffset)) + _, exists := p.usedIps[newIp.Uint64()]; if !exists { + timeNow := p.timeNow() + to_ret := &IdentityAssociation{clientId: clientId, + interfaceId: interfaceId, + ipAddress: newIp.Bytes(), + createdAt: timeNow, + t1: p.calculateT1(p.preferredLifetime), + t2: p.calculateT2(p.preferredLifetime) } + p.identityAssociations[clientIdHash] = to_ret + p.usedIps[newIp.Uint64()] = struct{}{} + p.identityAssociationExpirations.Push(&AssociationExpiration{expiresAt: p.calculateAssociationExpiration(timeNow, p.preferredLifetime), ia: to_ret}) + p.lock <- 1 + return to_ret + } + } + p.lock <- 1 + return nil +} + +func (p *RandomAddressPool) ReleaseAddress(clientId, interfaceId []byte, addr net.IP) { + <-p.lock + delete(p.identityAssociations, p.calculateIaIdHash(clientId, interfaceId)) + delete(p.usedIps, big.NewInt(0).SetBytes(addr).Uint64()) + p.lock <- 1 +} + +func (p *RandomAddressPool) ExpireIdentityAssociations() { + <-p.lock + for { + if p.identityAssociationExpirations.Size() < 1 { break } + expiration := p.identityAssociationExpirations.Peek().(*AssociationExpiration) + if p.timeNow().Before(expiration.expiresAt) { break } + p.identityAssociationExpirations.Shift() + delete(p.identityAssociations, p.calculateIaIdHash(expiration.ia.clientId, expiration.ia.interfaceId)) + delete(p.usedIps, big.NewInt(0).SetBytes(expiration.ia.ipAddress).Uint64()) + } + p.lock <- 1 +} + +func (p *RandomAddressPool) calculateT1(preferredLifetime uint32) uint32 { + return preferredLifetime / 2 +} + +func (p *RandomAddressPool) calculateT2(preferredLifetime uint32) uint32 { + return (preferredLifetime * 4)/5 +} + +func (p *RandomAddressPool) calculateAssociationExpiration(now time.Time, preferredLifetime uint32) time.Time { + return now.Add(time.Duration(p.preferredLifetime)*time.Second) +} + +func (p *RandomAddressPool) calculateIaIdHash(clientId, interfaceId []byte) uint64 { + h := fnv.New64a() + h.Write(clientId) + h.Write(interfaceId) + return h.Sum64() +} diff --git a/dhcp6/random_address_pool_test.go b/dhcp6/random_address_pool_test.go new file mode 100644 index 0000000..4b7f080 --- /dev/null +++ b/dhcp6/random_address_pool_test.go @@ -0,0 +1,133 @@ +package dhcp6 + +import ( + "testing" + "net" + "time" +) + +func TestReserveAddress(t *testing.T) { + expectedIp := net.ParseIP("2001:db8:f00f:cafe::1") + expectedClientId := []byte("client-id") + expectedIaId := []byte("interface-id") + expectedTime := time.Now() + expectedMaxLifetime := uint32(100) + + pool := NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), expectedMaxLifetime) + pool.timeNow = func() time.Time { return expectedTime } + ia := pool.ReserveAddress(expectedClientId, expectedIaId) + + if ia == nil { + t.Fatalf("Expected a non-nil identity association") + } + if string(ia.ipAddress) != string(expectedIp) { + t.Fatalf("Expected ip: %v, but got: %v", expectedIp, ia.ipAddress) + } + if string(ia.clientId) != string(expectedClientId) { + t.Fatalf("Expected client id: %v, but got: %v", expectedClientId, ia.clientId) + } + if string(ia.interfaceId) != string(expectedIaId) { + t.Fatalf("Expected interface id: %v, but got: %v", expectedIaId, ia.interfaceId) + } + if ia.createdAt != expectedTime { + t.Fatalf("Expected creation time: %v, but got: %v", expectedTime, ia.createdAt) + } + if ia.createdAt != expectedTime { + t.Fatalf("Expected creation time: %v, but got: %v", expectedTime, ia.createdAt) + } + expectedT1 := pool.calculateT1(expectedMaxLifetime); if ia.t1 != expectedT1 { + t.Fatalf("Expected creation t1: %v, but got: %v", expectedT1, ia.t1) + } + expectedT2 := pool.calculateT2(expectedMaxLifetime); if ia.t2 != expectedT2 { + t.Fatalf("Expected creation t2: %v, but got: %v", expectedT2, ia.t2) + } +} + +func TestReserveAddressUpdatesAddressPool(t *testing.T) { + expectedClientId := []byte("client-id") + expectedIaId := []byte("interface-id") + expectedTime := time.Now() + expectedMaxLifetime := uint32(100) + + pool := NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), expectedMaxLifetime) + pool.timeNow = func() time.Time { return expectedTime } + pool.ReserveAddress(expectedClientId, expectedIaId) + expectedIdx := pool.calculateIaIdHash(expectedClientId, expectedIaId) + + + a, exists := pool.identityAssociations[expectedIdx] + if !exists { + t.Fatalf("Expected to find identity association at %d but didn't", expectedIdx) + } + if string(a.clientId) != string(expectedClientId) || string(a.interfaceId) != string(expectedIaId) { + t.Fatalf("Expected ia association with client id %x and ia id %x, but got %x %x respectively", expectedClientId, expectedIaId, a.clientId, a.interfaceId) + } +} + +func TestReserveAddressKeepsTrackOfUsedAddresses(t *testing.T) { + expectedClientId := []byte("client-id") + expectedIaId := []byte("interface-id") + expectedTime := time.Now() + expectedMaxLifetime := uint32(100) + + pool := NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), expectedMaxLifetime) + pool.timeNow = func() time.Time { return expectedTime } + pool.ReserveAddress(expectedClientId, expectedIaId) + + _, exists := pool.usedIps[0x01]; if !exists { + t.Fatal("'2001:db8:f00f:cafe::1' should be marked as in use") + } +} + +func TestReserveAddressKeepsTrackOfAssociationExpiration(t *testing.T) { + expectedClientId := []byte("client-id") + expectedIaId := []byte("interface-id") + expectedTime := time.Now() + expectedMaxLifetime := uint32(100) + + pool := NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), expectedMaxLifetime) + pool.timeNow = func() time.Time { return expectedTime } + pool.ReserveAddress(expectedClientId, expectedIaId) + + expiration := pool.identityAssociationExpirations.Peek().(*AssociationExpiration) + if expiration == nil { + t.Fatal("Expected an identity association expiration, but got nil") + } + if expiration.expiresAt != pool.calculateAssociationExpiration(expectedTime, expectedMaxLifetime) { + t.Fatalf("Expected association to expire at %v, but got %v", + pool.calculateAssociationExpiration(expectedTime, expectedMaxLifetime), expiration.expiresAt) + } +} + +func TestReserveAddressReturnsExistingAssociation(t *testing.T) { + expectedClientId := []byte("client-id") + expectedIaId := []byte("interface-id") + expectedTime := time.Now() + expectedMaxLifetime := uint32(100) + + pool := NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), expectedMaxLifetime) + pool.timeNow = func() time.Time { return expectedTime } + firstAssociation := pool.ReserveAddress(expectedClientId, expectedIaId) + secondAssociation := pool.ReserveAddress(expectedClientId, expectedIaId) + + if string(firstAssociation.ipAddress) != string(secondAssociation.ipAddress) { + t.Fatal("Expected return of the same ip address on both invocations") + } +} + +func TestReleaseAddress(t *testing.T) { + expectedClientId := []byte("client-id") + expectedIaId := []byte("interface-id") + expectedTime := time.Now() + expectedMaxLifetime := uint32(100) + + pool := NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::1"), net.ParseIP("2001:db8:f00f:cafe::1"), expectedMaxLifetime) + pool.timeNow = func() time.Time { return expectedTime } + a := pool.ReserveAddress(expectedClientId, expectedIaId) + + pool.ReleaseAddress(expectedClientId, expectedIaId, a.ipAddress) + + _, exists := pool.identityAssociations[pool.calculateIaIdHash(expectedClientId, expectedIaId)]; if exists { + t.Fatalf("identity association for %v should've been removed, but is still available", a.ipAddress) + } +} diff --git a/pixiecore/dhcpv6.go b/pixiecore/dhcpv6.go index c553cf1..afcce3a 100644 --- a/pixiecore/dhcpv6.go +++ b/pixiecore/dhcpv6.go @@ -5,7 +5,7 @@ import ( "fmt" ) -func (s *ServerV6) serveDHCP(conn *dhcp6.Conn) error { +func (s *ServerV6) serveDHCP(conn *dhcp6.Conn, addressPool dhcp6.AddressPool) error { s.log("dhcpv6", "Waiting for packets...\n") for { pkt, src, err := conn.RecvDHCP() @@ -19,7 +19,7 @@ func (s *ServerV6) serveDHCP(conn *dhcp6.Conn) error { s.log("dhcpv6", fmt.Sprintf("Received (%d) packet (%d): %s\n", pkt.Type, pkt.TransactionID, pkt.Options.HumanReadable())) - response := pkt.BuildResponse(s.Duid) + response := pkt.BuildResponse(s.Duid, addressPool) marshalled_response, err := response.Marshal() if err != nil { s.log("dhcpv6", fmt.Sprintf("Error marshalling response: %s", response.Type, response.TransactionID, err)) diff --git a/pixiecore/pixicorev6.go b/pixiecore/pixicorev6.go index 918926d..2e77ea2 100644 --- a/pixiecore/pixicorev6.go +++ b/pixiecore/pixicorev6.go @@ -54,8 +54,10 @@ func (s *ServerV6) Serve() error { s.errs = make(chan error, 6) //s.debug("Init", "Starting Pixiecore goroutines") + addressPool := dhcp6.NewRandomAddressPool(net.ParseIP("2001:db8:f00f:cafe::10"), net.ParseIP("2001:db8:f00f:cafe::100"), 1800) + s.SetDUID(dhcp.SourceHardwareAddress()) - go func() { s.errs <- s.serveDHCP(dhcp) }() + go func() { s.errs <- s.serveDHCP(dhcp, addressPool) }() // Wait for either a fatal error, or Shutdown(). err = <-s.errs