added some dhcp4 packet functions and generic byte reading functions

This commit is contained in:
2025-09-17 19:31:37 +09:00
parent 67b3e9727b
commit 90365bfdd4
8 changed files with 394 additions and 65 deletions

141
server.go
View File

@ -17,22 +17,26 @@ type DhcpServer struct {
Ctx context.Context
CtxCancel context.CancelFunc
wg sync.WaitGroup
stop_req atomic.Bool
ext_mtx sync.Mutex
ext_svcs []Service
ext_closed bool
conns_mtx sync.Mutex
conns map[string]*Dhcp4Conn
conns_by_fd map[int]*Dhcp4Conn
conns map[string]DhcpConn
conns_by_fd map[int]DhcpConn
p int // epoll
efd int // eventfd
wg sync.WaitGroup
log Logger
}
func finalize_dhcp_server(s *DhcpServer) {
var dc *Dhcp4Conn
var dc DhcpConn
for _, dc = range s.conns_by_fd {
@ -42,9 +46,9 @@ func finalize_dhcp_server(s *DhcpServer) {
// unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, fd, nil)
//})
unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, dc.fd, nil)
delete(s.conns_by_fd, dc.fd)
delete(s.conns, dc.iface)
unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, dc.Fd(), nil)
delete(s.conns_by_fd, dc.Fd())
delete(s.conns, dc.Iface())
dc.Close()
}
@ -92,54 +96,50 @@ func NewDhcpServer(ctx context.Context, name string, logger Logger) (*DhcpServer
s = &DhcpServer{}
s.Ctx, s.CtxCancel = context.WithCancel(ctx)
s.name = name
s.log = logger
s.log = logger
s.stop_req.Store(false)
s.ext_svcs = make([]Service, 0, 1)
s.ext_closed = false
s.efd = efd
s.p = p
s.conns = make(map[string]*Dhcp4Conn, 0)
s.conns_by_fd = make(map[int]*Dhcp4Conn, 0)
s.conns = make(map[string]DhcpConn, 0)
s.conns_by_fd = make(map[int]DhcpConn, 0)
runtime.SetFinalizer(s, finalize_dhcp_server)
return s, nil
}
func (s *DhcpServer) AddListener4(iface string, addr *net.UDPAddr) error {
var conn *Dhcp4Conn
var conn DhcpConn
var ok bool
var err error
s.conns_mtx.Lock()
defer s.conns_mtx.Unlock()
conn, ok = s.conns[iface]
if ok {
err = conn.AddAddr(addr)
if err != nil {
s.conns_mtx.Unlock()
return err
}
if err != nil { return err }
} else {
conn, err = NewDhcp4Conn(iface, addr)
if err != nil {
fmt.Printf ("fAILED TO CRETEA %s\n", err.Error())
s.conns_mtx.Unlock()
return err
}
fmt.Printf ("ADDING dhcp4... %d\n", conn.fd)
err = unix.EpollCtl(s.p, unix.EPOLL_CTL_ADD, conn.fd,
if err != nil { return err }
fmt.Printf ("ADDING dhcp4... %d\n", conn.Fd())
err = unix.EpollCtl(s.p, unix.EPOLL_CTL_ADD, conn.Fd(),
&unix.EpollEvent{
Events: unix.EPOLLIN | unix.EPOLLERR | unix.EPOLLHUP,
Fd: int32(conn.fd),
Fd: int32(conn.Fd()),
},
)
if err != nil {
s.conns_mtx.Unlock()
conn.Close()
return err
}
s.conns[iface] = conn // this isn't needed??? or this must be a slice...
s.conns_by_fd[conn.Fd()] = conn
}
s.conns[iface] = conn // this isn't needed??? or this must be a slice...
s.conns_by_fd[conn.fd] = conn
s.conns_mtx.Unlock()
return nil
}
@ -176,15 +176,28 @@ func (s *DhcpServer) StartService(data interface{}) {
go s.RunTask(&s.wg)
}
func (s *DhcpServer) StartExtService(svc Service, data interface{}) {
s.ext_mtx.Lock()
if s.ext_closed {
// don't start it if it's already closed
s.ext_mtx.Unlock()
return
}
s.ext_svcs = append(s.ext_svcs, svc)
s.ext_mtx.Unlock()
s.wg.Add(1)
go svc.RunTask(&s.wg)
}
func (s *DhcpServer) StopServices() {
// var ext_svc Service
var ext_svc Service
s.ReqStop()
// s.ext_mtx.Lock()
// for _, ext_svc = range s.ext_svcs {
// ext_svc.StopServices()
// }
// s.ext_closed = true
// s.ext_mtx.Unlock()
s.ext_mtx.Lock()
for _, ext_svc = range s.ext_svcs {
ext_svc.StopServices()
}
s.ext_closed = true
s.ext_mtx.Unlock()
}
func (s *DhcpServer) FixServices() {
@ -200,45 +213,47 @@ func (s *DhcpServer) WriteLog(id string, level LogLevel, fmtstr string, args ...
s.log.Write(id, level, fmtstr, args...)
}
func get_offsets(frame []byte) (int, int, int, int, int, error) {
var eth_off int
var ip_off int
var ihl int
var udp_off int
var udp_len int
var dhcp_off int
var dhcp_end int
if len(frame) < 14 {
return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for Ethernet header")
}
ethOff := 0
ipOff := ethOff + 14
eth_off = 0
ip_off = eth_off + 14
if len(frame) < ipOff+20 {
if len(frame) < ip_off + 20 {
return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for IPv4 header")
}
// First byte of IPv4 header: 4 bits version, 4 bits IHL
ihl := int(frame[ipOff] & 0x0F) * 4
if len(frame) < ipOff+ihl {
return 0, 0, 0, 0, 0, fmt.Errorf("invalid IHL: %d", ihl)
ihl = int(frame[ip_off] & 0x0F) * 4 // the lower first 4 bits of the first bytes * 4
if len(frame) < ip_off + ihl {
return 0, 0, 0, 0, 0, fmt.Errorf("invalid IHL %d", ihl)
}
udpOff := ipOff + ihl
if len(frame) < udpOff+8 {
udp_off = ip_off + ihl
if len(frame) < udp_off + 8 {
return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP header")
}
// UDP length field (2 bytes at offset 45 of UDP header)
udpLen := int(frame[udpOff+4])<<8 | int(frame[udpOff+5])
if udpLen < 8 {
return 0, 0, 0, 0, 0, fmt.Errorf("invalid UDP length: %d", udpLen)
udp_len = int(frame[udp_off + 4]) << 8 | int(frame[udp_off + 5])
if udp_len < 8 {
return 0, 0, 0, 0, 0, fmt.Errorf("invalid UDP length %d", udp_len)
}
if len(frame) < udpOff+udpLen {
return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP length %d", udpLen)
if len(frame) < udp_off + udp_len {
return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP length %d", udp_len)
}
dhcpOff := udpOff + 8
dhcpEnd := udpOff + udpLen
dhcp_off = udp_off + 8
dhcp_end = udp_off + udp_len
return ethOff, ipOff, udpOff, dhcpOff, dhcpEnd, nil
return eth_off, ip_off, udp_off, dhcp_off, dhcp_end, nil
}
func (s *DhcpServer) run_recv_loop() error {
@ -253,13 +268,13 @@ epoll_loop:
nevts, err = unix.EpollWait(s.p, evts[:], -1)
if err != nil {
if errors.Is(err, unix.EINTR) { continue }
fmt.Printf ("epoll error - %s\n", err.Error())
break
s.log.Write("", LOG_ERROR, "event multiplexer error - %s", err.Error())
break epoll_loop
}
for i = 0; i < nevts ; i++ {
if (evts[i].Events & (unix.EPOLLHUP | unix.EPOLLERR)) != 0 {
fmt.Printf ("epoll something not right...\n")
s.log.Write("", LOG_ERROR, "event HUP or ERR detected on fd %d - %s", i)
break epoll_loop
}
@ -286,12 +301,16 @@ fmt.Printf ("epoll something not right...\n")
if string(buf[spos:epos]) == "quit\n" {
break epoll_loop
}
var pkt Dhcp4Pkt
pkt.Decode(buf[spos:epos])
fmt.Printf("%+v\n", pkt)
}
}
if err != nil {
// TODO: logging
fmt.Printf("ERROR... %s\n", err.Error())
s.log.Write("", LOG_ERROR, "Failed to read packet on fd %d - %s", i, err.Error())
break epoll_loop
}
}