diff --git a/Makefile b/Makefile index 5a0100b..47efb8c 100644 --- a/Makefile +++ b/Makefile @@ -10,12 +10,14 @@ VERSION=1.0.0 SRCS=\ haza.go \ - iface.go \ - server.go + pkt.go \ + server.go \ + sock.go CMD_SRCS=\ cmd/logger.go \ - cmd/main.go + cmd/main.go \ + cmd/signal.go all: $(NAME) diff --git a/cmd/main.go b/cmd/main.go index cbcde6f..0ad5eab 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -12,6 +12,7 @@ func main() { var logger *AppLogger var addr *net.UDPAddr var addr2 *net.UDPAddr + var addr3 *net.UDPAddr var err error // addr, err = net.ResolveUDPAddr("udp4", "0.0.0.0:10000") @@ -27,7 +28,13 @@ func main() { os.Exit(-1) } - logger = NewAppLogger("client", os.Stderr, haza.LOG_ALL) + addr3, err = net.ResolveUDPAddr("udp4", "192.168.1.189:10000") + if err != nil { + fmt.Printf("Failed to resolve address - %s\n", err.Error()) + os.Exit(-1) + } + + logger = NewAppLogger("server", os.Stderr, haza.LOG_ALL) ds, err = haza.NewDhcpServer(context.Background(), "haza-dhcpd", logger) if err != nil { @@ -49,8 +56,14 @@ if err != nil { fmt.Printf("Failed to add listgener for %v - %s\n", addr2, err.Error()) goto oops } +err = ds.AddListener4("enp1s0", addr3) +if err != nil { + fmt.Printf("Failed to add listgener for %v - %s\n", addr3, err.Error()) + goto oops +} ds.StartService(nil) + ds.StartExtService(&signal_handler{svc:ds}, nil) ds.WaitForTermination() os.Exit(0) diff --git a/cmd/signal.go b/cmd/signal.go new file mode 100644 index 0000000..d52628b --- /dev/null +++ b/cmd/signal.go @@ -0,0 +1,69 @@ +package main + +import "haza" +import "sync" +import "os" +import "os/signal" +import "syscall" + +type signal_handler struct { + svc haza.Service +} + +func (sh *signal_handler) RunTask(wg *sync.WaitGroup) { + var sighup_chan chan os.Signal + var sigterm_chan chan os.Signal + var sig os.Signal + + if wg != nil { + defer wg.Done() + } + + sighup_chan = make(chan os.Signal, 1) + sigterm_chan = make(chan os.Signal, 1) + + signal.Notify(sighup_chan, syscall.SIGHUP) + signal.Notify(sigterm_chan, syscall.SIGTERM, os.Interrupt) + +chan_loop: + for { + select { + case <-sighup_chan: + sh.svc.FixServices() + + case sig = <-sigterm_chan: + sh.svc.StopServices() + sh.svc.WriteLog ("", haza.LOG_INFO, "Received %s signal", sig) + break chan_loop + } + } + + //signal.Reset(syscall.SIGHUP) + //signal.Reset(syscall.SIGTERM) + signal.Stop(sighup_chan) + signal.Stop(sigterm_chan) +} + +func (sh *signal_handler) StartService(data interface{}) { + // this isn't actually used standalone.. + // if we are to implement it, it must use the wait group for signal handler itself + // however, this service is run through another service. + // sh.wg.Add(1) + // go sh.RunTask(&sh.wg) +} + +func (sh *signal_handler) StopServices() { + syscall.Kill(syscall.Getpid(), syscall.SIGTERM) // TODO: find a better to terminate the signal handler... +} + +func (sh *signal_handler) FixServices() { +} + +func (sh *signal_handler) WaitForTermination() { + // not implemented. see the comment in StartServices() + // sh.wg.Wait() +} + +func (sh *signal_handler) WriteLog(id string, level haza.LogLevel, fmt string, args ...interface{}) { + sh.svc.WriteLog(id, level, fmt, args...) +} diff --git a/haza.go b/haza.go index c6bcb4f..86c6009 100644 --- a/haza.go +++ b/haza.go @@ -1,5 +1,9 @@ package haza +import "bytes" +import "encoding/binary" +import "io" +import "net" import "sync" import "unsafe" @@ -75,3 +79,46 @@ func Ntoh32(v uint32) uint32 { if _is_big_endian { return v } return ((v >> 24) & 0xFF) | ((v >> 8) & 0xFF00) | ((v << 8) & 0xFF0000) | ((v << 24) & 0xFF000000) } + +// --------------------------------------------------------- +type ByteReader struct { + r *bytes.Reader +} + +func NewByteReader(b []byte) *ByteReader { + return &ByteReader{ r: bytes.NewReader(b) } +} + +func (br* ByteReader) ReadByte() (byte, error) { + return br.r.ReadByte() +} + +func (br* ByteReader) ReadUint16() (uint16, error) { + var v uint16 + var err error + err = binary.Read(br.r, binary.BigEndian, &v) + if err != nil { return 0, err } + return v, nil +} + +func (br* ByteReader) ReadUint32() (uint32, error) { + var v uint32 + var err error + err = binary.Read(br.r, binary.BigEndian, &v) + if err != nil { return 0, err } + return v, nil +} + +func (br* ByteReader) ReadIp4() (net.IP, error) { + var buf [4]byte + var err error + _, err = io.ReadFull(br.r, buf[:]) + if err != nil { return nil, err } + return net.IP(buf[:]), nil +} + +func (br *ByteReader) ReadAllBytes(buf []byte) error { + var err error + _, err = io.ReadFull(br.r, buf) + return err +} diff --git a/pkt.go b/pkt.go new file mode 100644 index 0000000..34b603d --- /dev/null +++ b/pkt.go @@ -0,0 +1,148 @@ +package haza + +import "net" + +type Dhcp4Op = uint8 +type Dhcp4Htype = uint8 +type Dhcp4Msg = uint8 + +const ( + _ Dhcp4Op = iota + DHCP4_OP_BOOTREQUEST = 1 + DHCP4_OP_BOOTREPLY = 2 +) + +const ( + _ Dhcp4Htype = iota + DHCP4_HTYPE_ETHER = 1 + DHCP4_HTYPE_IEEE802 = 6 + DHCP4_HTYPE_ARCNET = 7 + DHCP4_HTYPE_APPLETALK = 8 + DHCP4_HTYPE_HDLC = 17 + DHCP4_HTYPE_ATM = 19 + DHCP4_HTYPE_ARPSEC = 30 + DHCP4_HTYPE_IPSEC = 31 + DHCP4_HTYPE_INFINIBAND = 32 + DHCP4_HTYPE_PUREIP = 35 +) + +const ( + _ Dhcp4Msg = iota + + DHCP4_MSG_DISCOVER = 1 + DHCP4_MSG_OFFER = 2 + DHCP4_MSG_REQUEST = 3 + DHCP4_MSG_DECLINE = 4 + DHCP4_MSG_ACK = 5 + DHCP4_MSG_NAK = 6 + DHCP4_MSG_RELEASE = 7 + DHCP4_MSG_INFORM = 8 + + DHCP4_MSG_FORCE_RENEW = 9 + + DHCP4_MSG_LEASE_QUERY = 10 + DHCP4_MSG_LEASE_UNASSIGNED = 11 + DHCP4_MSG_LEASE_UNKNOWN = 12 + DHCP4_MSG_LEASE_ACTIVE = 13 + + DHCP4_MSG_BULK_LEASE_QUERY = 14 + DHCP4_MSG_LEASE_QUERY_DONE = 15 + DHCP4_MSG_ACTIVE_LEASE_QUERY = 16 + DHCP4_MSG_LEASE_QUERY_STATUS = 17 + DHCP4_MSG_TLS = 18 +) + +type Dhcp4Pkt struct { + Op Dhcp4Op + Htype Dhcp4Htype + Hlen uint8 // length of Chaddr + + Hops uint8 + Xid uint32 + Secs uint16 + Flags uint16 + + Ciaddr net.IP // uint32 + Yiaddr net.IP // uint32 + Siaddr net.IP // uint32 + Gwaddr net.IP // uint32 + Chaddr [16]byte + + Sname [64]byte // server host name + File [128]byte // boot file name + + // options are placed after the header. + // the first four bytes of the options compose a magic cookie + // 0x63 0x82 0x53 0x63 +} + +func (pkt *Dhcp4Pkt) Decode(b []byte) error { + // fill the packet with data from the bytes + var r *ByteReader + var u8 byte + var u16 uint16 + var u32 uint32 + var p Dhcp4Pkt + var err error + + r = NewByteReader(b) + + u8, err = r.ReadByte() + if err != nil { return err } + p.Op = Dhcp4Op(u8) + + u8, err = r.ReadByte() + if err != nil { return err } + p.Htype = Dhcp4Htype(u8) + + u8, err = r.ReadByte() + if err != nil { return err } + p.Hlen = u8 + + u8, err = r.ReadByte() + if err != nil { return err } + p.Hops = u8 + + u32, err = r.ReadUint32() + if err != nil { return err } + p.Xid = u32 + + u16, err = r.ReadUint16() + if err != nil { return err } + p.Secs = u16 + + u16, err = r.ReadUint16() + if err != nil { return err } + p.Flags = u16 + + p.Ciaddr, err = r.ReadIp4() + if err != nil { return err } + + p.Yiaddr, err = r.ReadIp4() + if err != nil { return err } + + p.Siaddr, err = r.ReadIp4() + if err != nil { return err } + + p.Gwaddr, err = r.ReadIp4() + if err != nil { return err } + + err = r.ReadAllBytes(p.Chaddr[:]) + if err != nil { return err } + + err = r.ReadAllBytes(p.Sname[:]) + if err != nil { return err } + + err = r.ReadAllBytes(p.File[:]) + if err != nil { return err } + +// magic +// options.. + + *pkt = p + return nil +} + +func (pkt *Dhcp4Pkt) Encode() []byte { + return nil +} diff --git a/server.go b/server.go index 04e9f5f..115569f 100644 --- a/server.go +++ b/server.go @@ -17,22 +17,26 @@ type DhcpServer struct { Ctx context.Context CtxCancel context.CancelFunc + wg sync.WaitGroup stop_req atomic.Bool + ext_mtx sync.Mutex + ext_svcs []Service + ext_closed bool + conns_mtx sync.Mutex - conns map[string]*Dhcp4Conn - conns_by_fd map[int]*Dhcp4Conn + conns map[string]DhcpConn + conns_by_fd map[int]DhcpConn p int // epoll efd int // eventfd - wg sync.WaitGroup log Logger } func finalize_dhcp_server(s *DhcpServer) { - var dc *Dhcp4Conn + var dc DhcpConn for _, dc = range s.conns_by_fd { @@ -42,9 +46,9 @@ func finalize_dhcp_server(s *DhcpServer) { // unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, fd, nil) //}) - unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, dc.fd, nil) - delete(s.conns_by_fd, dc.fd) - delete(s.conns, dc.iface) + unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, dc.Fd(), nil) + delete(s.conns_by_fd, dc.Fd()) + delete(s.conns, dc.Iface()) dc.Close() } @@ -92,54 +96,50 @@ func NewDhcpServer(ctx context.Context, name string, logger Logger) (*DhcpServer s = &DhcpServer{} s.Ctx, s.CtxCancel = context.WithCancel(ctx) s.name = name - s.log = logger + s.log = logger s.stop_req.Store(false) + s.ext_svcs = make([]Service, 0, 1) + s.ext_closed = false s.efd = efd s.p = p - s.conns = make(map[string]*Dhcp4Conn, 0) - s.conns_by_fd = make(map[int]*Dhcp4Conn, 0) + s.conns = make(map[string]DhcpConn, 0) + s.conns_by_fd = make(map[int]DhcpConn, 0) runtime.SetFinalizer(s, finalize_dhcp_server) return s, nil } func (s *DhcpServer) AddListener4(iface string, addr *net.UDPAddr) error { - var conn *Dhcp4Conn + var conn DhcpConn var ok bool var err error s.conns_mtx.Lock() + defer s.conns_mtx.Unlock() + conn, ok = s.conns[iface] if ok { err = conn.AddAddr(addr) - if err != nil { - s.conns_mtx.Unlock() - return err - } + if err != nil { return err } } else { conn, err = NewDhcp4Conn(iface, addr) - if err != nil { -fmt.Printf ("fAILED TO CRETEA %s\n", err.Error()) - s.conns_mtx.Unlock() - return err - } -fmt.Printf ("ADDING dhcp4... %d\n", conn.fd) - err = unix.EpollCtl(s.p, unix.EPOLL_CTL_ADD, conn.fd, + if err != nil { return err } +fmt.Printf ("ADDING dhcp4... %d\n", conn.Fd()) + err = unix.EpollCtl(s.p, unix.EPOLL_CTL_ADD, conn.Fd(), &unix.EpollEvent{ Events: unix.EPOLLIN | unix.EPOLLERR | unix.EPOLLHUP, - Fd: int32(conn.fd), + Fd: int32(conn.Fd()), }, ) if err != nil { - s.conns_mtx.Unlock() conn.Close() return err } + + s.conns[iface] = conn // this isn't needed??? or this must be a slice... + s.conns_by_fd[conn.Fd()] = conn } - s.conns[iface] = conn // this isn't needed??? or this must be a slice... - s.conns_by_fd[conn.fd] = conn - s.conns_mtx.Unlock() return nil } @@ -176,15 +176,28 @@ func (s *DhcpServer) StartService(data interface{}) { go s.RunTask(&s.wg) } +func (s *DhcpServer) StartExtService(svc Service, data interface{}) { + s.ext_mtx.Lock() + if s.ext_closed { + // don't start it if it's already closed + s.ext_mtx.Unlock() + return + } + s.ext_svcs = append(s.ext_svcs, svc) + s.ext_mtx.Unlock() + s.wg.Add(1) + go svc.RunTask(&s.wg) +} + func (s *DhcpServer) StopServices() { -// var ext_svc Service + var ext_svc Service s.ReqStop() -// s.ext_mtx.Lock() -// for _, ext_svc = range s.ext_svcs { -// ext_svc.StopServices() -// } -// s.ext_closed = true -// s.ext_mtx.Unlock() + s.ext_mtx.Lock() + for _, ext_svc = range s.ext_svcs { + ext_svc.StopServices() + } + s.ext_closed = true + s.ext_mtx.Unlock() } func (s *DhcpServer) FixServices() { @@ -200,45 +213,47 @@ func (s *DhcpServer) WriteLog(id string, level LogLevel, fmtstr string, args ... s.log.Write(id, level, fmtstr, args...) } - - - - func get_offsets(frame []byte) (int, int, int, int, int, error) { + var eth_off int + var ip_off int + var ihl int + var udp_off int + var udp_len int + var dhcp_off int + var dhcp_end int + if len(frame) < 14 { return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for Ethernet header") } - ethOff := 0 - ipOff := ethOff + 14 + eth_off = 0 + ip_off = eth_off + 14 - if len(frame) < ipOff+20 { + if len(frame) < ip_off + 20 { return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for IPv4 header") } - // First byte of IPv4 header: 4 bits version, 4 bits IHL - ihl := int(frame[ipOff] & 0x0F) * 4 - if len(frame) < ipOff+ihl { - return 0, 0, 0, 0, 0, fmt.Errorf("invalid IHL: %d", ihl) + ihl = int(frame[ip_off] & 0x0F) * 4 // the lower first 4 bits of the first bytes * 4 + if len(frame) < ip_off + ihl { + return 0, 0, 0, 0, 0, fmt.Errorf("invalid IHL %d", ihl) } - udpOff := ipOff + ihl - if len(frame) < udpOff+8 { + udp_off = ip_off + ihl + if len(frame) < udp_off + 8 { return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP header") } - // UDP length field (2 bytes at offset 4–5 of UDP header) - udpLen := int(frame[udpOff+4])<<8 | int(frame[udpOff+5]) - if udpLen < 8 { - return 0, 0, 0, 0, 0, fmt.Errorf("invalid UDP length: %d", udpLen) + udp_len = int(frame[udp_off + 4]) << 8 | int(frame[udp_off + 5]) + if udp_len < 8 { + return 0, 0, 0, 0, 0, fmt.Errorf("invalid UDP length %d", udp_len) } - if len(frame) < udpOff+udpLen { - return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP length %d", udpLen) + if len(frame) < udp_off + udp_len { + return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP length %d", udp_len) } - dhcpOff := udpOff + 8 - dhcpEnd := udpOff + udpLen + dhcp_off = udp_off + 8 + dhcp_end = udp_off + udp_len - return ethOff, ipOff, udpOff, dhcpOff, dhcpEnd, nil + return eth_off, ip_off, udp_off, dhcp_off, dhcp_end, nil } func (s *DhcpServer) run_recv_loop() error { @@ -253,13 +268,13 @@ epoll_loop: nevts, err = unix.EpollWait(s.p, evts[:], -1) if err != nil { if errors.Is(err, unix.EINTR) { continue } -fmt.Printf ("epoll error - %s\n", err.Error()) - break + s.log.Write("", LOG_ERROR, "event multiplexer error - %s", err.Error()) + break epoll_loop } for i = 0; i < nevts ; i++ { if (evts[i].Events & (unix.EPOLLHUP | unix.EPOLLERR)) != 0 { -fmt.Printf ("epoll something not right...\n") + s.log.Write("", LOG_ERROR, "event HUP or ERR detected on fd %d - %s", i) break epoll_loop } @@ -286,12 +301,16 @@ fmt.Printf ("epoll something not right...\n") if string(buf[spos:epos]) == "quit\n" { break epoll_loop } + + + var pkt Dhcp4Pkt + pkt.Decode(buf[spos:epos]) + fmt.Printf("%+v\n", pkt) } } if err != nil { - // TODO: logging - fmt.Printf("ERROR... %s\n", err.Error()) + s.log.Write("", LOG_ERROR, "Failed to read packet on fd %d - %s", i, err.Error()) break epoll_loop } } diff --git a/iface.go b/sock.go similarity index 95% rename from iface.go rename to sock.go index f33a5e5..b2b2561 100644 --- a/iface.go +++ b/sock.go @@ -8,6 +8,19 @@ import "sync" import "golang.org/x/net/bpf" import "golang.org/x/sys/unix" + +type DhcpConn interface { + Version() int + Fd() int + Iface() string + Close() + AddAddr(addr *net.UDPAddr) error + //DelAddr(addr *net.UDPAddr) error +} + + +// TODO: add dhcp6 sockets + type Dhcp4UdpSock struct { addr *net.UDPAddr fd int @@ -131,6 +144,18 @@ func (d *Dhcp4Conn) Close() { } } +func (d *Dhcp4Conn) Version() int { + return 4 +} + +func (d *Dhcp4Conn) Fd() int { + return d.fd +} + +func (d *Dhcp4Conn) Iface() string { + return d.iface +} + func (d *Dhcp4Conn) ForEachUdpSockFd(handler func (fd int)) { var us *Dhcp4UdpSock @@ -388,6 +413,12 @@ func open_udp4_socket(addr *net.UDPAddr) (int, error) { err = attach_discard_filter(fd) if err != nil { goto oops } + err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) + if err != nil { goto oops } + + // i won't be setting unix.SO_REUSEPORT as i don't want multple programs + // to run on the same port. + err = unix.Bind(fd, &unix.SockaddrInet4{ Addr: [4]byte(ip[0:4]), Port: addr.Port}) if err != nil { goto oops } diff --git a/iface_test.go b/sock_test.go similarity index 100% rename from iface_test.go rename to sock_test.go