commit 67b3e9727b92153458e1e46ce1aedf6ad25e5d23 Author: hyung-hwan Date: Mon Sep 15 20:06:21 2025 +0900 some initial code diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5a0100b --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +# make +# make GOARCH=386 +# make GOARCH=amd64 +# make GOOS=linux GOARCH=mips +# +# 'go tool dist list' for available os and architextures + +NAME=haza +VERSION=1.0.0 + +SRCS=\ + haza.go \ + iface.go \ + server.go + +CMD_SRCS=\ + cmd/logger.go \ + cmd/main.go + +all: $(NAME) + +$(NAME): $(DATA) $(SRCS) $(CMD_DATA) $(CMD_SRCS) + CGO_ENABLED=0 go build -x -ldflags "-X 'main.HAZA_NAME=$(NAME)' -X 'main.HAZA_VERSION=$(VERSION)'" -o $@ $(CMD_SRCS) + ##CGO_ENABLED=1 go build -x -ldflags "-X 'main.HAZA_NAME=$(NAME)' -X 'main.HAZA_VERSION=$(VERSION)'" -o $@ $(CMD_SRCS) + ##CGO_ENABLED=1 go build -x -ldflags "-X 'main.HAZA_NAME=$(NAME)' -X 'main.HAZA_VERSION=$(VERSION)' -linkmode external -extldflags=-static" -o $@ $(CMD_SRCS) + +$(NAME).debug: $(DATA) $(SRCS) $(CMD_DATA) $(CMD_SRCS) + CGO_ENABLED=1 go build -race -x -ldflags "-X 'main.HAZA_NAME=$(NAME)' -X 'main.HAZA_VERSION=$(VERSION)'" -o $@ $(CMD_SRCS) + +clean: + go clean -x -i + rm -f $(NAME) $(NAME).debug + +check: + go test -x + +cmd/tls.crt: + openssl req -x509 -newkey rsa:4096 -keyout cmd/tls.key -out cmd/tls.crt -sha256 -days 36500 -nodes -subj "/CN=$(NAME)" --addext "subjectAltName=DNS:$(NAME),IP:127.0.0.1,IP:::1" + +cmd/tls.key: + openssl req -x509 -newkey rsa:4096 -keyout cmd/tls.key -out cmd/tls.crt -sha256 -days 36500 -nodes -subj "/CN=$(NAME)" --addext "subjectAltName=DNS:$(NAME),IP:127.0.0.1,IP:::1" + +cmd/rsa.key: + openssl genrsa -traditional -out cmd/rsa.key 2048 + +.PHONY: all clean test diff --git a/cmd/logger.go b/cmd/logger.go new file mode 100644 index 0000000..7b16866 --- /dev/null +++ b/cmd/logger.go @@ -0,0 +1,272 @@ +package main + +import "fmt" +import "haza" +import "io" +import "os" +import "path/filepath" +import "runtime" +import "strings" +import "sync" +import "sync/atomic" +import "syscall" +import "time" + +type app_logger_msg_t struct { + code int + data string +} + +type AppLogger struct { + id string + out io.Writer + mask haza.LogMask + + file *os.File + file_name string // you can get the file name from file but this is to preserve the original. + file_rotate int + file_max_size int64 + msg_chan chan app_logger_msg_t + wg sync.WaitGroup + + use_color bool + closed atomic.Bool +} + +func _is_ansi_tty(fd uintptr) bool { + var st syscall.Stat_t + var err error + + err = syscall.Fstat(int(fd), &st) + if err != nil { return false } + if (st.Mode & syscall.S_IFMT) == syscall.S_IFCHR { + var term string + // i assume this fd is bound to the current terminal if it's a character device + // if the assumption is wrong, you simply get extraneous ansi code in the output. + term = os.Getenv("TERM") + if term != "" && term != "dumb" { return true } + } + + return false +} + + +func NewAppLogger(id string, w io.Writer, mask haza.LogMask) *AppLogger { + var l *AppLogger + var f *os.File + var ok bool + var use_color bool + + use_color = false + f, ok = w.(*os.File) + if ok { use_color = _is_ansi_tty(f.Fd()) } + + l = &AppLogger{ + id: id, + out: w, + mask: mask, + msg_chan: make(chan app_logger_msg_t, 256), + use_color: use_color, + } + l.closed.Store(false) + l.wg.Add(1) + go l.logger_task() + return l +} + +func NewAppLoggerToFile(id string, file_name string, max_size int64, rotate int, mask haza.LogMask) (*AppLogger, error) { + var l *AppLogger + var f *os.File + var matched bool + var err error + + f, err = os.OpenFile(file_name, os.O_CREATE | os.O_APPEND | os.O_WRONLY, 0666) + if err != nil { return nil, err } + + if os.PathSeparator == '/' { + // this check is performed only on systems where the path separator is /. + matched, _ = filepath.Match("/dev/*", file_name) + if matched { + // if the log file is under /dev, disable rotation + max_size = 0 + rotate = 0 + } + } + + l = &AppLogger{ + id: id, + out: f, + mask: mask, + file: f, + file_name: file_name, + file_max_size: max_size, + file_rotate: rotate, + msg_chan: make(chan app_logger_msg_t, 256), + use_color: _is_ansi_tty(f.Fd()), + } + l.closed.Store(false) + l.wg.Add(1) + go l.logger_task() + return l, nil +} + +func (l *AppLogger) Close() { + if l.closed.CompareAndSwap(false, true) { + l.msg_chan <- app_logger_msg_t{code: 1} + l.wg.Wait() + if l.file != nil { l.file.Close() } + } +} + +func (l *AppLogger) Rotate() { + l.msg_chan <- app_logger_msg_t{code: 2} +} + +func (l *AppLogger) logger_task() { + var msg app_logger_msg_t + defer l.wg.Done() + +main_loop: + for { + select { + case msg = <-l.msg_chan: + if msg.code == 0 { + //l.out.Write([]byte(msg)) + io.WriteString(l.out, msg.data) + if l.file_max_size > 0 && l.file != nil { + var fi os.FileInfo + var err error + fi, err = l.file.Stat() + if err == nil && fi.Size() >= l.file_max_size { + l.rotate() + } + } + } else if msg.code == 1 { + break main_loop + } else if msg.code == 2 { + l.rotate() + } + // other code must not appear here. + } + } +} + +func (l *AppLogger) Write(id string, level haza.LogLevel, fmtstr string, args ...interface{}) { + if l.mask & haza.LogMask(level) == 0 { return } + l.write(id, level, 1, fmtstr, args...) +} + +func (l *AppLogger) WriteWithCallDepth(id string, level haza.LogLevel, call_depth int, fmtstr string, args ...interface{}) { + if l.mask & haza.LogMask(level) == 0 { return } + l.write(id, level, call_depth + 1, fmtstr, args...) +} + +func (l *AppLogger) write(id string, level haza.LogLevel, call_depth int, fmtstr string, args ...interface{}) { + var now time.Time + var off_m int + var off_h int + var off_s int + var msg string + var callerfile string + var caller_line int + var caller_ok bool + var sb strings.Builder + + //if l.mask & haza.LogMask(level) == 0 { return } + + now = time.Now() + + _, off_s = now.Zone() + off_m = off_s / 60 + off_h = off_m / 60 + off_m = off_m % 60 + if off_m < 0 { off_m = -off_m } + + sb.WriteString( + fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d %+03d%02d ", + now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second(), off_h, off_m)) + + _, callerfile, caller_line, caller_ok = runtime.Caller(1 + call_depth) + + if caller_ok { + sb.WriteString(fmt.Sprintf("[%s:%d] ", filepath.Base(callerfile), caller_line)) + } + sb.WriteString(l.id) + if id != "" { + sb.WriteString("(") + sb.WriteString(id) + sb.WriteString(")") + } + sb.WriteString(": ") + msg = fmt.Sprintf(fmtstr, args...) + if (l.use_color) { + var code string + code = l.log_level_to_ansi_code(level) + sb.WriteString(code) + sb.WriteString(msg) + if code != "" { sb.WriteString("\x1B[0m") } + } else { + sb.WriteString(msg) + } + if msg[len(msg) - 1] != '\n' { sb.WriteRune('\n') } + + // use queue to avoid blocking operation as much as possible + l.msg_chan <- app_logger_msg_t{ code: 0, data: sb.String() } +} + +func (l *AppLogger) rotate() { + var f *os.File + var fi os.FileInfo + var i int + var last_rot_no int + var err error + + if l.file == nil { return } + if l.file_rotate <= 0 { return } + + fi, err = l.file.Stat() + if err == nil && fi.Size() <= 0 { return } + + for i = l.file_rotate - 1; i > 0; i-- { + if os.Rename(fmt.Sprintf("%s.%d", l.file_name, i), fmt.Sprintf("%s.%d", l.file_name, i + 1)) == nil { + if last_rot_no == 0 { last_rot_no = i + 1 } + } + } + if os.Rename(l.file_name, fmt.Sprintf("%s.%d", l.file_name, 1)) == nil { + if last_rot_no == 0 { last_rot_no = 1 } + } + + f, err = os.OpenFile(l.file_name, os.O_CREATE | os.O_TRUNC | os.O_APPEND | os.O_WRONLY, 0666) + if err != nil { + l.file.Close() + l.file = nil + l.out = os.Stderr + // don't reset l.file_name. you can derive that there was an error + // if l.file_name is not blank, and if l.out is os.Stderr, + } else { + l.file.Close() + l.file = f + l.out = l.file + } +} + +func (l* AppLogger) log_level_to_ansi_code(level haza.LogLevel) string { + switch level { + case haza.LOG_ERROR: + return "\x1B[31m" // red + + case haza.LOG_WARN: + return "\x1B[33m" // yellow + + case haza.LOG_INFO: + if (l.mask & haza.LogMask(haza.LOG_DEBUG)) != 0 { + // if debug is enabled, change the color of info. + // otherwisse no color + return "\x1B[32m" // green + } + fallthrough + + default: + return "" + } +} diff --git a/cmd/main.go b/cmd/main.go new file mode 100644 index 0000000..cbcde6f --- /dev/null +++ b/cmd/main.go @@ -0,0 +1,59 @@ +package main + +import "context" +import "fmt" +import "haza" +import "net" +import "os" +//import "time" + +func main() { + var ds *haza.DhcpServer + var logger *AppLogger + var addr *net.UDPAddr + var addr2 *net.UDPAddr + var err error + +// addr, err = net.ResolveUDPAddr("udp4", "0.0.0.0:10000") + addr, err = net.ResolveUDPAddr("udp4", "192.168.1.130:10000") + if err != nil { + fmt.Printf("Failed to resolve address - %s\n", err.Error()) + os.Exit(-1) + } + + addr2, err = net.ResolveUDPAddr("udp4", "192.168.1.114:10000") + if err != nil { + fmt.Printf("Failed to resolve address - %s\n", err.Error()) + os.Exit(-1) + } + + logger = NewAppLogger("client", os.Stderr, haza.LOG_ALL) + + ds, err = haza.NewDhcpServer(context.Background(), "haza-dhcpd", logger) + if err != nil { + fmt.Printf("Failed to create dhcp server - %s\n", err.Error()) + goto oops + } + +// consider this if you want some more freedom on supported systems +// sysctl -w net.ipv4.ip_nonlocal_bind=1 + +//ds.AddListener("tun0", addr) +err = ds.AddListener4("enp1s0", addr) +if err != nil { + fmt.Printf("Failed to add listgener for %v - %s\n", addr, err.Error()) + goto oops +} +err = ds.AddListener4("enp1s0", addr2) +if err != nil { + fmt.Printf("Failed to add listgener for %v - %s\n", addr2, err.Error()) + goto oops +} + + ds.StartService(nil) + ds.WaitForTermination() + os.Exit(0) + +oops: + os.Exit(-1) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c9872b9 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module haza + +go 1.24.2 + +require ( + golang.org/x/net v0.43.0 // indirect + golang.org/x/sys v0.35.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4ae27a4 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/haza.go b/haza.go new file mode 100644 index 0000000..c6bcb4f --- /dev/null +++ b/haza.go @@ -0,0 +1,77 @@ +package haza + +import "sync" +import "unsafe" + +type LogLevel int +type LogMask int + +const ( + LOG_DEBUG LogLevel = 1 << iota + LOG_INFO + LOG_WARN + LOG_ERROR +) + +const LOG_ALL LogMask = LogMask(LOG_DEBUG | LOG_INFO | LOG_WARN | LOG_ERROR) +const LOG_NONE LogMask = LogMask(0) + +type Named struct { + name string +} + +type Logger interface { + Write(id string, level LogLevel, fmtstr string, args ...interface{}) + WriteWithCallDepth(id string, level LogLevel, call_depth int, fmtstr string, args ...interface{}) + Rotate() + Close() +} + +type Service interface { + RunTask(wg *sync.WaitGroup) // blocking. run the actual task loop. it must call wg.Done() upon exit from itself. + StartService(data interface{}) // non-blocking. spin up a service. it may be invokded multiple times for multiple instances + StopServices() // non-blocking. send stop request to all services spun up + FixServices() // do some fixup as needed + WaitForTermination() // blocking. must wait until all services are stopped + WriteLog(id string, level LogLevel, fmtstr string, args ...interface{}) +} + +// --------------------------------------------------------- + +func (n *Named) SetName(name string) { + n.name = name +} + +func (n *Named) Name() string { + return n.name +} + +// --------------------------------------------------------- + +var _is_big_endian bool = is_big_endian() + +func is_big_endian() bool { + var v uint16 + v = 1 + return *(*byte)(unsafe.Pointer(&v)) == 0; +} + +func Hton16(v uint16) uint16 { + if _is_big_endian { return v } + return (v << 8) | (v >> 8) +} + +func Ntoh16(v uint16) uint16 { + if _is_big_endian { return v } + return (v << 8) | (v >> 8) +} + +func Hton32(v uint32) uint32 { + if _is_big_endian { return v } + return ((v >> 24) & 0xFF) | ((v >> 8) & 0xFF00) | ((v << 8) & 0xFF0000) | ((v << 24) & 0xFF000000) +} + +func Ntoh32(v uint32) uint32 { + if _is_big_endian { return v } + return ((v >> 24) & 0xFF) | ((v >> 8) & 0xFF00) | ((v << 8) & 0xFF0000) | ((v << 24) & 0xFF000000) +} diff --git a/iface.go b/iface.go new file mode 100644 index 0000000..f33a5e5 --- /dev/null +++ b/iface.go @@ -0,0 +1,398 @@ +package haza + +import "encoding/binary" +import "fmt" +import "net" +import "sync" + +import "golang.org/x/net/bpf" +import "golang.org/x/sys/unix" + +type Dhcp4UdpSock struct { + addr *net.UDPAddr + fd int +} + +type Dhcp4Conn struct { + fd int + iface string + + udp_socks []*Dhcp4UdpSock + mtx sync.Mutex +} + +func NewDhcp4Conn(iface string, addr *net.UDPAddr) (*Dhcp4Conn, error) { + var fd int + var fd2 int + var proto uint16 + var d *Dhcp4Conn + var udp_socks []*Dhcp4UdpSock + var ip net.IP + var err error + + fd = -1 + fd2 = -1 + + ip = addr.IP.To4() + if ip == nil { + err = fmt.Errorf("invalid address - %v", addr) + goto oops + } + + // Eventfd is linux specific. TODO: port it to other OSes. + proto = Hton16(unix.ETH_P_IP) + //proto = Hton16(unix.ETH_P_ALL) + fd, err = unix.Socket(unix.AF_PACKET, unix.SOCK_RAW | unix.SOCK_CLOEXEC, int(proto)) + if err != nil { goto oops } + + fd2, err = open_udp4_socket(addr) + if err != nil { goto oops } + + if iface != "" { + var sll *unix.SockaddrLinklayer + var nif *net.Interface + + nif, err = net.InterfaceByName(iface) + if err != nil { goto oops } + +// TODO: if the interface is not an ethernet device, fail it. +/* +tun0 +Interface type -> Encapsulation Type: Raw IP(7) + +eth0 +Encapsulation type: Ethernet(1) + +unix.SockaddrLinklayer -> Hatype + +/sys/class/net/tun0/type + -> 65534 (ARPHRD_NONE) + +/sys/class/net/eth0/type + -> 1 (ARPHRD_ETHER) + +#define ARPHRD_NETROM 0 * from KA9Q: NETROM pseudo * +#define ARPHRD_ETHER 1 * Ethernet 10Mbps * +#define ARPHRD_EETHER 2 * Experimental Ethernet * +#define ARPHRD_AX25 3 * AX.25 Level 2 * +#define ARPHRD_PRONET 4 * PROnet token ring * +#define ARPHRD_CHAOS 5 * Chaosnet * +#define ARPHRD_IEEE802 6 * IEEE 802.2 EthernetTRTB * +#define ARPHRD_ARCNET 7 * ARCnet * +#define ARPHRD_APPLETLK 8 * APPLEtalk * +#define ARPHRD_DLCI 15 * Frame Relay DLCI * +#define ARPHRD_ATM 19 * ATM * +#define ARPHRD_METRICOM 23 * Metricom STRIP (new IANA id) * +#define ARPHRD_IEEE1394 24 * IEEE 1394 IPv4 - RFC 2734 * +#define ARPHRD_EUI64 27 * EUI-64 * +#define ARPHRD_INFINIBAND 32 * InfiniBand * + +.... +#define ARPHRD_VOID 0xFFFF * Void type, nothing is known * +#define ARPHRD_NONE 0xFFFE * zero header length * +*/ + + sll = &unix.SockaddrLinklayer{ + Protocol: proto, + Ifindex: nif.Index, + } + err = unix.Bind(fd, sll) + if err != nil { goto oops } + } + + udp_socks = make([]*Dhcp4UdpSock, 1) + udp_socks[0] = &Dhcp4UdpSock{ fd: fd2, addr: addr } + + err = attach_dhcp4_filter(fd, iface, udp_socks) + if err != nil { goto oops } + + d = &Dhcp4Conn{fd: fd, iface: iface, udp_socks: udp_socks} + return d, nil + +oops: + if fd2 >= 0 { unix.Close(fd2) } + if fd >= 0 { unix.Close(fd) } + return nil, err +} + +func (d *Dhcp4Conn) Close() { + var us *Dhcp4UdpSock + + for _, us = range d.udp_socks { + if us.fd >= 0 { + unix.Close(us.fd) + us.fd = -1 + } + } + + if d.fd >= 0 { + unix.Close(d.fd) + d.fd = -1 + } +} + +func (d *Dhcp4Conn) ForEachUdpSockFd(handler func (fd int)) { + var us *Dhcp4UdpSock + + for _, us = range d.udp_socks { + if us.fd >= 0 { + // the handler must not close the fd + handler(us.fd) + } + } +} + +func (d *Dhcp4Conn) AddAddr(addr *net.UDPAddr) error { + var fd int + var udp_socks []*Dhcp4UdpSock + var err error + + fd, err = open_udp4_socket(addr) + if err != nil { return err } + + err = attach_discard_filter(fd) + if err != nil { return err } + + d.mtx.Lock() + defer d.mtx.Unlock() + + udp_socks = make([]*Dhcp4UdpSock, len(d.udp_socks) + 1) + copy(udp_socks, d.udp_socks) + udp_socks[len(d.udp_socks)] = &Dhcp4UdpSock{fd: fd, addr: addr} + + err = attach_dhcp4_filter(d.fd, d.iface, udp_socks) + if err != nil { + unix.Close(fd) + return err + } + + d.udp_socks = udp_socks + return nil +} + +/* +func (d *Dhcp4Conn) DelAddr(addr *net.UDPAddr) error { + return nil +}*/ + +func sockaddrToString(sa unix.Sockaddr) string { + switch v := sa.(type) { + case *unix.SockaddrInet4: + ip := net.IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3]) + return fmt.Sprintf("%s:%d", ip.String(), v.Port) + case *unix.SockaddrInet6: + ip := net.IP(v.Addr[:]) + return fmt.Sprintf("[%s]:%d", ip.String(), v.Port) + case *unix.SockaddrUnix: + return v.Name + case *unix.SockaddrLinklayer: + return fmt.Sprintf("ifindex=%d, proto=0x%04x addr % x", v.Ifindex, Ntoh16(v.Protocol), v.Addr[:v.Halen]) + default: + return fmt.Sprintf("unknown sockaddr type: %T", v) + } +} + +func patch_jump_if_skip_true(insts *[]bpf.Instruction, i uint8, offset uint8) { + var inst bpf.JumpIf + inst = (*insts)[i].(bpf.JumpIf) + inst.SkipTrue = offset + (*insts)[i] = inst +} + +func patch_jump_if_skip_false(insts *[]bpf.Instruction, i uint8, offset uint8) { + var inst bpf.JumpIf + inst = (*insts)[i].(bpf.JumpIf) + inst.SkipFalse = offset + (*insts)[i] = inst +} + +func attach_discard_filter(fd int) error { + var insts []bpf.Instruction + var raw_insts []bpf.RawInstruction + var flts []unix.SockFilter + var fprog *unix.SockFprog + var i int + var err error + + insts = []bpf.Instruction{ + bpf.RetConstant{Val: 0x0}, + } + + raw_insts, err = bpf.Assemble(insts) + if err != nil { goto oops } + + flts = make([]unix.SockFilter, len(raw_insts)) + for i, _ = range raw_insts { + flts[i] = unix.SockFilter{ + Code: raw_insts[i].Op, + Jt: raw_insts[i].Jt, + Jf: raw_insts[i].Jf, + K: raw_insts[i].K, + } + } + + fprog = &unix.SockFprog{ + Len: uint16(len(flts)), + Filter: &flts[0], + } + err = unix.SetsockoptSockFprog(fd, unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, fprog) + if err != nil { goto oops } + + return nil + + +oops: + return err +} + +func attach_dhcp4_filter(fd int, iface string, udp_socks []*Dhcp4UdpSock) error { + var insts []bpf.Instruction + var raw_insts []bpf.RawInstruction + var flts []unix.SockFilter + var fprog *unix.SockFprog + var us *Dhcp4UdpSock + var i int + var ninsts int + var ipos uint8 + var accept_pos uint8 + var reject_pos uint8 + var err error + + ninsts = 8; + for _, us = range udp_socks { + if us.addr.IP.To4() == nil { break } + + if !us.addr.IP.IsUnspecified() { + ninsts = ninsts + 1 + } + ninsts = ninsts + 3 + } + + ninsts = ninsts + 2 + if ninsts > 255 { + if iface != "" { + return fmt.Errorf("too many ip addresses on %s", iface) + } else { + return fmt.Errorf("too many ip addresses") + } + } + + insts = make([]bpf.Instruction, ninsts) + + // JumpIf instructions are patched below after ip address and port check instructions are produced + insts[0] = bpf.LoadAbsolute{Off: 12, Size: 2} // load ethernet type(2 bytes) at offset 12. + insts[1] = bpf.JumpIf{Cond: bpf.JumpEqual, Val: unix.ETH_P_IP, SkipTrue: 0, SkipFalse: 0} + insts[2] = bpf.LoadAbsolute{Off: 14 + 9, Size: 1} // load IP protocol (offset 14 + 9) + insts[3] = bpf.JumpIf{Cond: bpf.JumpEqual, Val: unix.IPPROTO_UDP, SkipTrue: 0, SkipFalse: 0} + insts[4] = bpf.LoadAbsolute{Off: 14 + 6, Size: 2} // load IP flags & fragment offset (2 bytes) + insts[5] = bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x1FFF, SkipTrue: 0, SkipFalse: 0} + insts[6] = bpf.LoadAbsolute{Off: 14 + 16, Size: 4} // destination ip address + insts[7] = bpf.JumpIf{Cond: bpf.JumpEqual, Val: 0xFFFFFFFF, SkipTrue: 0, SkipFalse: 0} + + ipos = 8 + for _, us = range udp_socks { + if us.addr.IP.To4() == nil { break } + + if !us.addr.IP.IsUnspecified() { + var b net.IP + // check more desination ip addresses + b = us.addr.IP.To4() // must not fail as the check was done already + insts[ipos] = bpf.JumpIf{Cond: bpf.JumpEqual, Val: binary.BigEndian.Uint32(b), SkipTrue: 0, SkipFalse: 3} // ip match. go to port match + ipos = ipos + 1 + } + + // compute IP header length (IHL) into X register. LDX is like this: + // +---------------+-----------------+---+---+---+ + // | AddrMode (3b) | LoadWidth (2b) | 0 | 0 | 1 | + // +---------------+-----------------+---+---+---+ + // this one is supposed to be LDX(0x1) | B(0x10) | MSH(0xA0) + insts[ipos] = bpf.RawInstruction{Op: (0x1 | 0x10 | 0xA0), Jt: 0, Jf: 0, K: 14 } + ipos = ipos + 1 + + // load UDP destination port: offset = Ethernet(14) + 2 (UDP dest port) + IP header length * 4 + insts[ipos] = bpf.LoadIndirect{Off: 14 + 2, Size: 2} + ipos = ipos + 1 + insts[ipos] = bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(us.addr.Port), SkipTrue: 0, SkipFalse: 0} + ipos = ipos + 1 + } + + insts[ipos] = bpf.RetConstant{Val: 0} // reject + reject_pos = ipos + ipos = ipos + 1 + insts[ipos] = bpf.RetConstant{Val: 0xFFFF} // accept + accept_pos = ipos + ipos = ipos + 1 + + patch_jump_if_skip_false(&insts, 1, reject_pos - 1 - 1) // discard if not ETH_P_IP + patch_jump_if_skip_false(&insts, 3, reject_pos - 3 - 1) // discard if not UDP + patch_jump_if_skip_true(&insts, 5, reject_pos - 5 - 1) // discard fragmented packets + patch_jump_if_skip_true(&insts, 7, accept_pos - 7 - 1) // accept broadcast + + ipos = 8 + for _, us = range udp_socks { + if us.addr.IP.To4() == nil { break } + + if !us.addr.IP.IsUnspecified() { + ipos = ipos + 1 + } + + ipos = ipos + 1 // skip bpf.RawInstruction + ipos = ipos + 1 // skip bpf.LoadIndirect + patch_jump_if_skip_true(&insts, ipos, accept_pos - ipos - 1) // accept if port match + ipos = ipos + 1 + } + + raw_insts, err = bpf.Assemble(insts) + if err != nil { goto oops } + + flts = make([]unix.SockFilter, len(raw_insts)) + for i, _ = range raw_insts { + flts[i] = unix.SockFilter{ + Code: raw_insts[i].Op, + Jt: raw_insts[i].Jt, + Jf: raw_insts[i].Jf, + K: raw_insts[i].K, + } + } + + fprog = &unix.SockFprog{ + Len: uint16(len(flts)), + Filter: &flts[0], + } + err = unix.SetsockoptSockFprog(fd, unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, fprog) + if err != nil { goto oops } + + return nil + +oops: + return err +} + +func open_udp4_socket(addr *net.UDPAddr) (int, error) { + var fd int + var ip net.IP + var err error + + ip = addr.IP.To4() + if ip == nil { + err = fmt.Errorf("invalid address - %v", addr) + goto oops + } + + // create a real udp packet to deceive the kernel. + fd, err = unix.Socket(unix.AF_INET, unix.SOCK_DGRAM | unix.SOCK_CLOEXEC, unix.IPPROTO_UDP) + if err != nil { goto oops } + + // we won't be using this socket to receive packets. + // attach a filter that discard incoming packets + err = attach_discard_filter(fd) + if err != nil { goto oops } + + err = unix.Bind(fd, &unix.SockaddrInet4{ Addr: [4]byte(ip[0:4]), Port: addr.Port}) + if err != nil { goto oops } + + return fd, nil + +oops: + return -1, err +} diff --git a/iface_test.go b/iface_test.go new file mode 100644 index 0000000..71bb11b --- /dev/null +++ b/iface_test.go @@ -0,0 +1,34 @@ +package haza_test + +//import "fmt" +import "haza" +import "net" +import "testing" + +func TestDhcp4Conn(t *testing.T) { + var c *haza.Dhcp4Conn + var addr *net.UDPAddr + var err error + + addr, err = net.ResolveUDPAddr("udp6", "[::1]:1158") + if err != nil { + t.Errorf("failed to resolve address - %s\n", err.Error()) + } else { + c, err = haza.NewDhcp4Conn("", addr) + if err == nil { + t.Errorf("this must fail as v6 address is given to v4 conn - %v\n", addr) + } + } + + addr, err = net.ResolveUDPAddr("udp4", "127.0.0.1:1158") + if err != nil { + t.Errorf("failed to resolve address - %s\n", err.Error()) + } else { + c, err = haza.NewDhcp4Conn("", addr) + if err != nil { + t.Errorf("failed to create dhcp4 conn - %s\n", err.Error()) + } else { + c.Close() + } + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..04e9f5f --- /dev/null +++ b/server.go @@ -0,0 +1,304 @@ +package haza + +import "context" +import "errors" +import "fmt" +import "net" +import "runtime" +import "sync" +import "sync/atomic" +import "unsafe" + +import "golang.org/x/sys/unix" + +type DhcpServer struct { + Named + + Ctx context.Context + CtxCancel context.CancelFunc + + stop_req atomic.Bool + + conns_mtx sync.Mutex + conns map[string]*Dhcp4Conn + conns_by_fd map[int]*Dhcp4Conn + + p int // epoll + efd int // eventfd + wg sync.WaitGroup + + + log Logger +} + +func finalize_dhcp_server(s *DhcpServer) { + var dc *Dhcp4Conn + + for _, dc = range s.conns_by_fd { + + // the internal udp sockets wasn't addded to epoll. + // something liek the following isn't needed. + //dc.ForEachUdpSockFd(func (fd int) { + // unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, fd, nil) + //}) + + unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, dc.fd, nil) + delete(s.conns_by_fd, dc.fd) + delete(s.conns, dc.iface) + dc.Close() + } + + if s.efd >= 0 { + if s.p >= 0 { unix.EpollCtl(s.p, unix.EPOLL_CTL_DEL, s.efd, nil) } + unix.Close(s.efd) + s.efd = -1 + } + + if s.p >= 0 { + unix.Close(s.p) + s.p = -1 + } +} + +func NewDhcpServer(ctx context.Context, name string, logger Logger) (*DhcpServer, error) { + var s *DhcpServer + var p int + var efd int + var err error + + p, err = unix.EpollCreate1(0) + if err != nil { + return nil, fmt.Errorf("unable to create epoll - %s", err.Error()) + } + + efd, err = unix.Eventfd(0, unix.EFD_CLOEXEC) + if err != nil { + unix.Close(p) + return nil, err + } + + err = unix.EpollCtl(p, unix.EPOLL_CTL_ADD, efd, + &unix.EpollEvent{ + Events: unix.EPOLLIN | unix.EPOLLERR | unix.EPOLLHUP, + Fd: int32(efd), + }, + ) + if err != nil { + unix.Close(efd) + unix.Close(p) + return nil, err + } + + s = &DhcpServer{} + s.Ctx, s.CtxCancel = context.WithCancel(ctx) + s.name = name + s.log = logger + s.stop_req.Store(false) + s.efd = efd + s.p = p + s.conns = make(map[string]*Dhcp4Conn, 0) + s.conns_by_fd = make(map[int]*Dhcp4Conn, 0) + + runtime.SetFinalizer(s, finalize_dhcp_server) + return s, nil +} + +func (s *DhcpServer) AddListener4(iface string, addr *net.UDPAddr) error { + var conn *Dhcp4Conn + var ok bool + var err error + + s.conns_mtx.Lock() + conn, ok = s.conns[iface] + if ok { + err = conn.AddAddr(addr) + if err != nil { + s.conns_mtx.Unlock() + return err + } + } else { + conn, err = NewDhcp4Conn(iface, addr) + if err != nil { +fmt.Printf ("fAILED TO CRETEA %s\n", err.Error()) + s.conns_mtx.Unlock() + return err + } +fmt.Printf ("ADDING dhcp4... %d\n", conn.fd) + err = unix.EpollCtl(s.p, unix.EPOLL_CTL_ADD, conn.fd, + &unix.EpollEvent{ + Events: unix.EPOLLIN | unix.EPOLLERR | unix.EPOLLHUP, + Fd: int32(conn.fd), + }, + ) + if err != nil { + s.conns_mtx.Unlock() + conn.Close() + return err + } + } + + s.conns[iface] = conn // this isn't needed??? or this must be a slice... + s.conns_by_fd[conn.fd] = conn + s.conns_mtx.Unlock() + return nil +} + +func (s *DhcpServer) RunTask(wg *sync.WaitGroup) { +// var l_wg sync.WaitGroup + + defer wg.Done() + +/* + for i = 0; i < 10; i++ { // TODO: configured number of handlersz... + l_wg.Add(1) + go s.run_handlers(&l_wg) + } +*/ +fmt.Printf ("beginning of recv loop\n") + s.run_recv_loop() +fmt.Printf ("end of recv loop\n") + +// l_wg.Wait() + finalize_dhcp_server(s) +} + +func (s *DhcpServer) ReqStop() { + if s.stop_req.CompareAndSwap(false, true) { + var v uint64 + // eventfd needs an 8-byte integer. + v = 1 + unix.Write(s.efd, (*[8]byte)(unsafe.Pointer(&v))[:]) + } +} + +func (s *DhcpServer) StartService(data interface{}) { + s.wg.Add(1) + go s.RunTask(&s.wg) +} + +func (s *DhcpServer) StopServices() { +// var ext_svc Service + s.ReqStop() +// s.ext_mtx.Lock() +// for _, ext_svc = range s.ext_svcs { +// ext_svc.StopServices() +// } +// s.ext_closed = true +// s.ext_mtx.Unlock() +} + +func (s *DhcpServer) FixServices() { + s.log.Rotate() +} + +func (s *DhcpServer) WaitForTermination() { + s.wg.Wait() + s.log.Write("", LOG_INFO, "End of service") +} + +func (s *DhcpServer) WriteLog(id string, level LogLevel, fmtstr string, args ...interface{}) { + s.log.Write(id, level, fmtstr, args...) +} + + + + + +func get_offsets(frame []byte) (int, int, int, int, int, error) { + if len(frame) < 14 { + return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for Ethernet header") + } + ethOff := 0 + ipOff := ethOff + 14 + + if len(frame) < ipOff+20 { + return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for IPv4 header") + } + + // First byte of IPv4 header: 4 bits version, 4 bits IHL + ihl := int(frame[ipOff] & 0x0F) * 4 + if len(frame) < ipOff+ihl { + return 0, 0, 0, 0, 0, fmt.Errorf("invalid IHL: %d", ihl) + } + + udpOff := ipOff + ihl + if len(frame) < udpOff+8 { + return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP header") + } + + // UDP length field (2 bytes at offset 4–5 of UDP header) + udpLen := int(frame[udpOff+4])<<8 | int(frame[udpOff+5]) + if udpLen < 8 { + return 0, 0, 0, 0, 0, fmt.Errorf("invalid UDP length: %d", udpLen) + } + if len(frame) < udpOff+udpLen { + return 0, 0, 0, 0, 0, fmt.Errorf("frame too short for UDP length %d", udpLen) + } + + dhcpOff := udpOff + 8 + dhcpEnd := udpOff + udpLen + + return ethOff, ipOff, udpOff, dhcpOff, dhcpEnd, nil +} + +func (s *DhcpServer) run_recv_loop() error { + var buf [1500]byte + var nevts int + var i int + var evts [128]unix.EpollEvent + var err error + +epoll_loop: + for { + nevts, err = unix.EpollWait(s.p, evts[:], -1) + if err != nil { + if errors.Is(err, unix.EINTR) { continue } +fmt.Printf ("epoll error - %s\n", err.Error()) + break + } + + for i = 0; i < nevts ; i++ { + if (evts[i].Events & (unix.EPOLLHUP | unix.EPOLLERR)) != 0 { +fmt.Printf ("epoll something not right...\n") + break epoll_loop + } + + if int(evts[i].Fd) == s.efd { + if ((evts[i].Events & unix.EPOLLIN) != 0) { + unix.Read(s.efd, buf[:]) + } + break epoll_loop + } else { + if (evts[i].Events & unix.EPOLLIN) != 0 { + var n int + var from unix.Sockaddr + + n, from, err = unix.Recvfrom(int(evts[i].Fd), buf[:], 0) + fmt.Printf("%d %v [% x] %v\n", n, sockaddrToString(from), buf[:n], err) + if n > 0 { + var spos int + var epos int + _, _, _, spos, epos, err = get_offsets(buf[:n]) + if err != nil { + fmt.Printf ("PACKET -> %s\n", err.Error()) + } else { + //fmt.Printf (">>%d [%s]<<\n", epos - spos, string(buf[spos: epos])) + if string(buf[spos:epos]) == "quit\n" { + break epoll_loop + } + } + } + + if err != nil { + // TODO: logging + fmt.Printf("ERROR... %s\n", err.Error()) + break epoll_loop + } + } + } + } + } + + // if there are subtasks wait here for termination and close the dockets + return nil +}