Files
haza/server.go
2025-09-15 20:06:21 +09:00

305 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package haza
import "context"
import "errors"
import "fmt"
import "net"
import "runtime"
import "sync"
import "sync/atomic"
import "unsafe"
import "golang.org/x/sys/unix"
// DhcpServer is a DHCPv4 server whose receive loop multiplexes its listener
// sockets and a stop eventfd through a single epoll instance. Build it with
// NewDhcpServer; the zero value has nil maps and unset descriptors and is not
// usable.
type DhcpServer struct {
	Named // NewDhcpServer assigns s.name through this embedded type

	// Ctx is derived from the context given to NewDhcpServer and is cancelled
	// by CtxCancel. NOTE(review): the receive loop in this file does not
	// watch Ctx; stopping goes through ReqStop.
	Ctx       context.Context
	CtxCancel context.CancelFunc

	// stop_req flips to true on the first ReqStop so the eventfd is
	// signalled at most once.
	stop_req atomic.Bool

	// conns_mtx guards conns and conns_by_fd while listeners are added.
	conns_mtx   sync.Mutex
	conns       map[string]*Dhcp4Conn // listener connection per interface name
	conns_by_fd map[int]*Dhcp4Conn    // same connections keyed by their epoll-registered fd

	p   int // epoll descriptor; reset to -1 by finalize_dhcp_server
	efd int // eventfd written by ReqStop to wake the loop; reset to -1 when closed

	wg  sync.WaitGroup // tracks the RunTask goroutine launched by StartService
	log Logger
}
// finalize_dhcp_server releases every descriptor owned by s. It performs
// three steps:
//   - each listener connection is removed from the epoll set, dropped from
//     both connection maps, and closed;
//   - the eventfd is deregistered and closed;
//   - the epoll descriptor is closed.
//
// It is called explicitly by RunTask once the receive loop ends, and it is
// also installed as the runtime finalizer by NewDhcpServer. It must
// therefore be safe to run more than once: the maps are emptied as they are
// processed and closed descriptors are reset to -1.
func finalize_dhcp_server(s *DhcpServer) {
	var dc *Dhcp4Conn

	// AddListener4 mutates the connection maps under conns_mtx; take the
	// same lock so teardown from the RunTask goroutine cannot race with it.
	s.conns_mtx.Lock()
	for _, dc = range s.conns_by_fd {
		// Only the connection's own fd was added to epoll. Its internal UDP
		// sockets were never registered, so they need no EPOLL_CTL_DEL.
		if s.p >= 0 {
			unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, dc.fd, nil)
		}
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(s.conns_by_fd, dc.fd)
		delete(s.conns, dc.iface)
		dc.Close()
	}
	s.conns_mtx.Unlock()

	if s.efd >= 0 {
		if s.p >= 0 {
			unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, s.efd, nil)
		}
		unix.Close(s.efd)
		s.efd = -1
	}
	if s.p >= 0 {
		unix.Close(s.p)
		s.p = -1
	}
}
// NewDhcpServer creates a DhcpServer named name that logs through logger.
//
// It allocates an epoll instance and an eventfd, both close-on-exec, and
// registers the eventfd with epoll so ReqStop can wake the receive loop.
// Listener sockets are added later with AddListener4. The server's Ctx is
// derived from ctx and is cancelled by CtxCancel. finalize_dhcp_server is
// installed as a finalizer so the descriptors are released if the server
// becomes unreachable.
//
// On failure every descriptor created so far is closed and a wrapped
// error is returned; the underlying errno can be matched with errors.Is.
func NewDhcpServer(ctx context.Context, name string, logger Logger) (*DhcpServer, error) {
	var s *DhcpServer
	var p int
	var efd int
	var err error

	// EPOLL_CLOEXEC matches the EFD_CLOEXEC used for the eventfd below, so
	// neither descriptor leaks into child processes.
	p, err = unix.EpollCreate1(unix.EPOLL_CLOEXEC)
	if err != nil {
		return nil, fmt.Errorf("unable to create epoll - %w", err)
	}
	efd, err = unix.Eventfd(0, unix.EFD_CLOEXEC)
	if err != nil {
		unix.Close(p)
		return nil, fmt.Errorf("unable to create eventfd - %w", err)
	}
	err = unix.EpollCtl(p, unix.EPOLL_CTL_ADD, efd,
		&unix.EpollEvent{
			Events: unix.EPOLLIN | unix.EPOLLERR | unix.EPOLLHUP,
			Fd:     int32(efd),
		},
	)
	if err != nil {
		unix.Close(efd)
		unix.Close(p)
		return nil, fmt.Errorf("unable to register eventfd with epoll - %w", err)
	}

	s = &DhcpServer{}
	s.Ctx, s.CtxCancel = context.WithCancel(ctx)
	s.name = name
	s.log = logger
	s.stop_req.Store(false)
	s.efd = efd
	s.p = p
	s.conns = make(map[string]*Dhcp4Conn)
	s.conns_by_fd = make(map[int]*Dhcp4Conn)
	// finalize_dhcp_server resets closed descriptors to -1, so it is harmless
	// for this finalizer to run after RunTask has already cleaned up.
	runtime.SetFinalizer(s, finalize_dhcp_server)
	return s, nil
}
// AddListener4 registers addr as a DHCPv4 listening address on interface
// iface.
//
// If the server already has a connection for iface, addr is added to it
// with AddAddr. Otherwise a new Dhcp4Conn is created for iface/addr and its
// descriptor is added to the server's epoll set (EPOLLIN, EPOLLERR,
// EPOLLHUP) so run_recv_loop services it. If that registration fails, the
// new connection is closed. On success the connection is indexed both by
// interface name and by descriptor.
//
// Errors from AddAddr, NewDhcp4Conn and epoll_ctl are returned unchanged
// and are not printed; the caller decides how to report them.
func (s *DhcpServer) AddListener4(iface string, addr *net.UDPAddr) error {
	s.conns_mtx.Lock()
	defer s.conns_mtx.Unlock()

	conn, ok := s.conns[iface]
	if ok {
		if err := conn.AddAddr(addr); err != nil {
			return err
		}
	} else {
		var err error
		conn, err = NewDhcp4Conn(iface, addr)
		if err != nil {
			return err
		}
		err = unix.EpollCtl(s.p, unix.EPOLL_CTL_ADD, conn.fd,
			&unix.EpollEvent{
				Events: unix.EPOLLIN | unix.EPOLLERR | unix.EPOLLHUP,
				Fd:     int32(conn.fd),
			},
		)
		if err != nil {
			conn.Close()
			return err
		}
	}

	// Indexing is done unconditionally; for an existing connection this just
	// rewrites the same entries. NOTE(review): whether conns should hold more
	// than one connection per interface is still undecided.
	s.conns[iface] = conn
	s.conns_by_fd[conn.fd] = conn
	return nil
}
// RunTask runs the server on the calling goroutine and marks wg done when it
// returns. It blocks in run_recv_loop until the loop ends, then releases the
// listener sockets, the eventfd and the epoll descriptor via
// finalize_dhcp_server. StartService launches it on its own goroutine.
func (s *DhcpServer) RunTask(wg *sync.WaitGroup) {
	defer wg.Done()

	// A pool of request-handler goroutines is planned (TODO: configurable
	// count) but not started yet; every packet is handled inside the loop.
	fmt.Printf("beginning of recv loop\n")
	_ = s.run_recv_loop() // the loop's result is not acted upon yet
	fmt.Printf("end of recv loop\n")

	finalize_dhcp_server(s)
}
// ReqStop asks the receive loop to exit. Only the first call does anything:
// it sets stop_req and adds 1 to the eventfd counter. This makes the eventfd
// readable, and run_recv_loop returns when it sees that event. Later calls
// are no-ops. The outcome of the write is not checked.
func (s *DhcpServer) ReqStop() {
	if !s.stop_req.CompareAndSwap(false, true) {
		return
	}
	// An eventfd write takes an 8-byte integer in host byte order, so hand
	// the kernel the raw in-memory bytes of a uint64.
	one := uint64(1)
	unix.Write(s.efd, unsafe.Slice((*byte)(unsafe.Pointer(&one)), 8))
}
// StartService launches RunTask on a new goroutine and registers it with the
// server's wait group so WaitForTermination can wait for it. The data
// argument is accepted for interface compatibility and is not used.
func (s *DhcpServer) StartService(data interface{}) {
	group := &s.wg
	group.Add(1)
	go s.RunTask(group)
}
// StopServices asks the server to shut down by delegating to ReqStop. It
// returns immediately; use WaitForTermination to wait for the receive
// goroutine to finish.
//
// TODO: once external sub-services exist, stop them here as well.
func (s *DhcpServer) StopServices() {
	s.ReqStop()
}
// FixServices performs periodic maintenance on the server. Currently it only
// calls s.log.Rotate (presumably log-file rotation; confirm against the
// Logger implementation).
func (s *DhcpServer) FixServices() {
s.log.Rotate()
}
// WaitForTermination blocks until every goroutine registered on s.wg has
// finished (the RunTask goroutine started by StartService). It then writes
// an "End of service" entry at LOG_INFO through the server's logger.
func (s *DhcpServer) WaitForTermination() {
s.wg.Wait()
s.log.Write("", LOG_INFO, "End of service")
}
// WriteLog forwards one log entry to the server's logger. id, level, the
// format string and its arguments are passed to s.log.Write unchanged.
func (s *DhcpServer) WriteLog(id string, level LogLevel, fmtstr string, args ...interface{}) {
s.log.Write(id, level, fmtstr, args...)
}
// get_offsets locates the headers of an Ethernet/IPv4/UDP frame. It returns,
// in order:
//   - the Ethernet header offset (always 0);
//   - the IPv4 header offset (always 14; no VLAN tag is expected);
//   - the UDP header offset;
//   - the start of the UDP payload (the DHCP message);
//   - the end of the payload, as given by the UDP length field.
//
// All offsets index frame. An error is returned when any of the following
// holds:
//   - the frame is too short for a header or for the advertised UDP length;
//   - the IP version is not 4;
//   - the IHL is below the 20-byte minimum;
//   - the IP protocol is not UDP;
//   - the UDP length is smaller than the 8-byte UDP header.
func get_offsets(frame []byte) (int, int, int, int, int, error) {
	const (
		ethHdrLen    = 14
		ipv4MinHdr   = 20
		ipProtoOff   = 9 // offset of the protocol byte inside the IPv4 header
		ipProtoUDP   = 17
		udpHdrLen    = 8
		udpLenOffset = 4 // the UDP length field occupies bytes 4-5 of the UDP header
	)

	if len(frame) < ethHdrLen {
		return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for Ethernet header")
	}
	ethOff := 0
	ipOff := ethOff + ethHdrLen
	if len(frame) < ipOff+ipv4MinHdr {
		return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for IPv4 header")
	}

	// First byte of the IPv4 header: high nibble is the version, low nibble
	// is the header length (IHL) in 32-bit words.
	if version := frame[ipOff] >> 4; version != 4 {
		return 0, 0, 0, 0, 0, fmt.Errorf("not an IPv4 packet: version %d", version)
	}
	ihl := int(frame[ipOff]&0x0F) * 4
	// An IHL below 5 words is malformed; accepting it would place the UDP
	// header inside the IPv4 header.
	if ihl < ipv4MinHdr || len(frame) < ipOff+ihl {
		return 0, 0, 0, 0, 0, fmt.Errorf("invalid IHL: %d", ihl)
	}
	if proto := frame[ipOff+ipProtoOff]; proto != ipProtoUDP {
		return 0, 0, 0, 0, 0, fmt.Errorf("not a UDP packet: protocol %d", proto)
	}

	udpOff := ipOff + ihl
	if len(frame) < udpOff+udpHdrLen {
		return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP header")
	}
	// The UDP length counts the 8-byte header plus the payload (big-endian).
	udpLen := int(frame[udpOff+udpLenOffset])<<8 | int(frame[udpOff+udpLenOffset+1])
	if udpLen < udpHdrLen {
		return 0, 0, 0, 0, 0, fmt.Errorf("invalid UDP length: %d", udpLen)
	}
	if len(frame) < udpOff+udpLen {
		return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP length %d", udpLen)
	}

	dhcpOff := udpOff + udpHdrLen
	dhcpEnd := udpOff + udpLen
	return ethOff, ipOff, udpOff, dhcpOff, dhcpEnd, nil
}
// run_recv_loop services the server's epoll set on the calling goroutine.
// It keeps running until one of the following happens:
//   - ReqStop signals the eventfd s.efd (returns nil);
//   - a UDP payload of exactly "quit\n" arrives (returns nil);
//   - a watched descriptor reports EPOLLHUP or EPOLLERR (returns an error);
//   - epoll_wait fails with anything but EINTR, or recvfrom fails with
//     anything but EINTR/EAGAIN (returns a wrapped error).
//
// Every received frame is dumped to stdout and parsed with get_offsets. A
// frame that fails to parse is reported and skipped; it does not end the
// loop. Descriptors are not closed here; RunTask does that afterwards.
func (s *DhcpServer) run_recv_loop() error {
	var buf [1500]byte
	var evts [128]unix.EpollEvent

	for {
		nevts, err := unix.EpollWait(s.p, evts[:], -1)
		if err != nil {
			if errors.Is(err, unix.EINTR) {
				continue
			}
			fmt.Printf("epoll error - %s\n", err.Error())
			return fmt.Errorf("epoll wait - %w", err)
		}

		for i := 0; i < nevts; i++ {
			events := evts[i].Events
			fd := int(evts[i].Fd)

			if events&(unix.EPOLLHUP|unix.EPOLLERR) != 0 {
				fmt.Printf("epoll something not right...\n")
				return fmt.Errorf("epoll reported hangup/error on fd %d", fd)
			}

			if fd == s.efd {
				// Consume the counter ReqStop added; its value is irrelevant.
				if events&unix.EPOLLIN != 0 {
					unix.Read(s.efd, buf[:])
				}
				return nil
			}

			if events&unix.EPOLLIN == 0 {
				continue
			}

			n, from, rerr := unix.Recvfrom(fd, buf[:], 0)
			if rerr != nil {
				// n is -1 on failure, so buf[:n] must not be evaluated here.
				// Interrupted calls and spurious wakeups are simply retried on
				// the next epoll_wait.
				if errors.Is(rerr, unix.EINTR) || errors.Is(rerr, unix.EAGAIN) {
					continue
				}
				// TODO: logging
				fmt.Printf("ERROR... %s\n", rerr.Error())
				return fmt.Errorf("recvfrom on fd %d - %w", fd, rerr)
			}
			fmt.Printf("%d %v [% x] %v\n", n, sockaddrToString(from), buf[:n], rerr)
			if n <= 0 {
				continue
			}

			// A parse failure concerns only this frame, which any host on the
			// link may have sent, so report it and keep serving.
			_, _, _, spos, epos, perr := get_offsets(buf[:n])
			if perr != nil {
				fmt.Printf("PACKET -> %s\n", perr.Error())
				continue
			}
			// NOTE(review): debug shutdown hook. Any sender whose UDP payload
			// is exactly "quit\n" stops the server; remove or restrict this
			// before running on an untrusted network.
			if string(buf[spos:epos]) == "quit\n" {
				return nil
			}
		}
	}
}