package haza import "context" import "errors" import "fmt" import "net" import "runtime" import "sync" import "sync/atomic" import "unsafe" import "golang.org/x/sys/unix" type DhcpServer struct { Named Ctx context.Context CtxCancel context.CancelFunc wg sync.WaitGroup stop_req atomic.Bool ext_mtx sync.Mutex ext_svcs []Service ext_closed bool conns_mtx sync.Mutex conns map[string]DhcpConn conns_by_fd map[int]DhcpConn p int // epoll efd int // eventfd log Logger } func finalize_dhcp_server(s *DhcpServer) { var dc DhcpConn for _, dc = range s.conns_by_fd { // the internal udp sockets wasn't addded to epoll. // something liek the following isn't needed. //dc.ForEachUdpSockFd(func (fd int) { // unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, fd, nil) //}) unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, dc.Fd(), nil) delete(s.conns_by_fd, dc.Fd()) delete(s.conns, dc.Iface()) dc.Close() } if s.efd >= 0 { if s.p >= 0 { unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, s.efd, nil) } unix.Close(s.efd) s.efd = -1 } if s.p >= 0 { unix.Close(s.p) s.p = -1 } } func NewDhcpServer(ctx context.Context, name string, logger Logger) (*DhcpServer, error) { var s *DhcpServer var p int var efd int var err error p, err = unix.EpollCreate1(0) if err != nil { return nil, fmt.Errorf("unable to create epoll - %s", err.Error()) } efd, err = unix.Eventfd(0, unix.EFD_CLOEXEC) if err != nil { unix.Close(p) return nil, err } err = unix.EpollCtl(p, unix.EPOLL_CTL_ADD, efd, &unix.EpollEvent{ Events: unix.EPOLLIN | unix.EPOLLERR | unix.EPOLLHUP, Fd: int32(efd), }, ) if err != nil { unix.Close(efd) unix.Close(p) return nil, err } s = &DhcpServer{} s.Ctx, s.CtxCancel = context.WithCancel(ctx) s.name = name s.log = logger s.stop_req.Store(false) s.ext_svcs = make([]Service, 0, 1) s.ext_closed = false s.efd = efd s.p = p s.conns = make(map[string]DhcpConn, 0) s.conns_by_fd = make(map[int]DhcpConn, 0) runtime.SetFinalizer(s, finalize_dhcp_server) return s, nil } func (s *DhcpServer) AddListener4(iface string, addr *net.UDPAddr) error { var conn DhcpConn var ok bool var err error s.conns_mtx.Lock() defer s.conns_mtx.Unlock() conn, ok = s.conns[iface] if ok { err = conn.AddAddr(addr) if err != nil { return err } } else { conn, err = NewDhcp4Conn(iface, addr) if err != nil { return err } fmt.Printf ("ADDING dhcp4... %d\n", conn.Fd()) err = unix.EpollCtl(s.p, unix.EPOLL_CTL_ADD, conn.Fd(), &unix.EpollEvent{ Events: unix.EPOLLIN | unix.EPOLLERR | unix.EPOLLHUP, Fd: int32(conn.Fd()), }, ) if err != nil { conn.Close() return err } s.conns[iface] = conn // this isn't needed??? or this must be a slice... s.conns_by_fd[conn.Fd()] = conn } return nil } func (s *DhcpServer) RunTask(wg *sync.WaitGroup) { // var l_wg sync.WaitGroup defer wg.Done() /* for i = 0; i < 10; i++ { // TODO: configured number of handlersz... l_wg.Add(1) go s.run_handlers(&l_wg) } */ fmt.Printf ("beginning of recv loop\n") s.run_recv_loop() fmt.Printf ("end of recv loop\n") // l_wg.Wait() finalize_dhcp_server(s) } func (s *DhcpServer) ReqStop() { if s.stop_req.CompareAndSwap(false, true) { var v uint64 // eventfd needs an 8-byte integer. v = 1 unix.Write(s.efd, (*[8]byte)(unsafe.Pointer(&v))[:]) } } func (s *DhcpServer) StartService(data interface{}) { s.wg.Add(1) go s.RunTask(&s.wg) } func (s *DhcpServer) StartExtService(svc Service, data interface{}) { s.ext_mtx.Lock() if s.ext_closed { // don't start it if it's already closed s.ext_mtx.Unlock() return } s.ext_svcs = append(s.ext_svcs, svc) s.ext_mtx.Unlock() s.wg.Add(1) go svc.RunTask(&s.wg) } func (s *DhcpServer) StopServices() { var ext_svc Service s.ReqStop() s.ext_mtx.Lock() for _, ext_svc = range s.ext_svcs { ext_svc.StopServices() } s.ext_closed = true s.ext_mtx.Unlock() } func (s *DhcpServer) FixServices() { s.log.Rotate() } func (s *DhcpServer) WaitForTermination() { s.wg.Wait() s.log.Write("", LOG_INFO, "End of service") } func (s *DhcpServer) WriteLog(id string, level LogLevel, fmtstr string, args ...interface{}) { s.log.Write(id, level, fmtstr, args...) } func get_offsets(frame []byte) (int, int, int, int, int, error) { var eth_off int var ip_off int var ihl int var udp_off int var udp_len int var dhcp_off int var dhcp_end int if len(frame) < 14 { return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for Ethernet header") } eth_off = 0 ip_off = eth_off + 14 if len(frame) < ip_off + 20 { return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for IPv4 header") } ihl = int(frame[ip_off] & 0x0F) * 4 // the lower first 4 bits of the first bytes * 4 if len(frame) < ip_off + ihl { return 0, 0, 0, 0, 0, fmt.Errorf("invalid IHL %d", ihl) } udp_off = ip_off + ihl if len(frame) < udp_off + 8 { return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP header") } udp_len = int(frame[udp_off + 4]) << 8 | int(frame[udp_off + 5]) if udp_len < 8 { return 0, 0, 0, 0, 0, fmt.Errorf("invalid UDP length %d", udp_len) } if len(frame) < udp_off + udp_len { return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP length %d", udp_len) } dhcp_off = udp_off + 8 dhcp_end = udp_off + udp_len return eth_off, ip_off, udp_off, dhcp_off, dhcp_end, nil } func (s *DhcpServer) run_recv_loop() error { var buf [1500]byte var nevts int var i int var evts [128]unix.EpollEvent var err error epoll_loop: for { nevts, err = unix.EpollWait(s.p, evts[:], -1) if err != nil { if errors.Is(err, unix.EINTR) { continue } s.log.Write("", LOG_ERROR, "event multiplexer error - %s", err.Error()) break epoll_loop } for i = 0; i < nevts ; i++ { if (evts[i].Events & (unix.EPOLLHUP | unix.EPOLLERR)) != 0 { s.log.Write("", LOG_ERROR, "event HUP or ERR detected on fd %d - %s", i) break epoll_loop } if int(evts[i].Fd) == s.efd { if ((evts[i].Events & unix.EPOLLIN) != 0) { unix.Read(s.efd, buf[:]) } break epoll_loop } else { if (evts[i].Events & unix.EPOLLIN) != 0 { var n int var from unix.Sockaddr n, from, err = unix.Recvfrom(int(evts[i].Fd), buf[:], 0) fmt.Printf("%d %v [% x] %v\n", n, sockaddrToString(from), buf[:n], err) if n > 0 { var spos int var epos int _, _, _, spos, epos, err = get_offsets(buf[:n]) if err != nil { fmt.Printf ("PACKET -> %s\n", err.Error()) } else { //fmt.Printf (">>%d [%s]<<\n", epos - spos, string(buf[spos: epos])) if string(buf[spos:epos]) == "quit\n" { break epoll_loop } var pkt Dhcp4Pkt pkt.Decode(buf[spos:epos]) fmt.Printf("%+v\n", pkt) } } if err != nil { s.log.Write("", LOG_ERROR, "Failed to read packet on fd %d - %s", i, err.Error()) break epoll_loop } } } } } // if there are subtasks wait here for termination and close the dockets return nil }