Files
hodu/client-cts-rxc.go

689 lines
17 KiB
Go

package hodu
import "bytes"
import "encoding/gob"
import "errors"
import "fmt"
import "io"
import "os"
import "os/exec"
import "os/user"
import "strconv"
import "sync"
import "syscall"
import "unicode"
import "unicode/utf8"
import "golang.org/x/sys/unix"
func (cts *ClientConn) FindClientRxcById(id uint64) *ClientRxc {
var crp *ClientRxc
var ok bool
cts.rxc_mtx.Lock()
crp, ok = cts.rxc_map[id]
cts.rxc_mtx.Unlock()
if !ok { crp = nil }
return crp
}
func (cts *ClientConn) RxcLoop(crp *ClientRxc, wg *sync.WaitGroup) {
var poll_fds []unix.PollFd
var buf [2048]byte
var n int
var out_revents int16
var err_revents int16
var sig_revents int16
var stop_flags uint64
var stop_msg string
var stdout_open bool
var stderr_open bool
var err error
defer wg.Done()
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Started rxc(%d) loop for %s(%s) - %s", crp.id, crp.req_type, crp.req_script, crp.cmd.String())
cts.C.stats.rxc_sessions.Add(1)
stdout_open = true
stderr_open = true
poll_fds = []unix.PollFd{
unix.PollFd{Fd: int32(crp.stdout.Fd()), Events: unix.POLLIN},
unix.PollFd{Fd: int32(crp.stderr.Fd()), Events: unix.POLLIN},
unix.PollFd{Fd: int32(crp.pfd[0]), Events: unix.POLLIN},
}
for {
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
if err != nil {
if errors.Is(err, unix.EINTR) { continue }
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to poll rxc(%d) stdout - %s", crp.id, err.Error())
break
}
if n == 0 { continue } // timed out
out_revents = poll_fds[0].Revents
err_revents = poll_fds[1].Revents
sig_revents = poll_fds[2].Revents
if stdout_open && (out_revents & unix.POLLIN) != 0 {
n, err = crp.stdout.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRxcDataPacket(crp.id, RXC_DATA_FLAG_NONE, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send %s from rxc(%d) stdout to server - %s", PACKET_KIND_RXC_DATA.String(), crp.id, err2.Error())
break
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read rxc(%d) stdout - %s", crp.id, err.Error())
break
}
poll_fds[0].Fd = -1
poll_fds[0].Events = 0
stdout_open = false
}
}
if stderr_open && (err_revents & unix.POLLIN) != 0 {
n, err = crp.stderr.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRxcDataPacket(crp.id, RXC_DATA_FLAG_STDERR, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send %s from rxc(%d) stderr to server - %s", PACKET_KIND_RXC_DATA.String(), crp.id, err2.Error())
break
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read rxc(%d) stderr - %s", crp.id, err.Error())
break
}
poll_fds[1].Fd = -1
poll_fds[1].Events = 0
stderr_open = false
}
}
if (sig_revents & unix.POLLIN) != 0 {
// don't care to read the pipe as it is closed after the loop
//unix.Read(crp.pfd[0], )
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Stop request noticed on rxc(%d) signal pipe", crp.id)
break
}
if stdout_open && (out_revents & (unix.POLLERR | unix.POLLNVAL)) != 0 {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Error detected on rxc(%d) stdout", crp.id)
break
}
if stderr_open && (err_revents & (unix.POLLERR | unix.POLLNVAL)) != 0 {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Error detected on rxc(%d) stderr", crp.id)
break
}
if sig_revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL) != 0 {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rxc(%d) signal pipe", crp.id)
break
}
if stdout_open && (out_revents & unix.POLLHUP) != 0 && (out_revents & unix.POLLIN) == 0 {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rxc(%d) stdout", crp.id)
poll_fds[0].Fd = -1
poll_fds[0].Events = 0
stdout_open = false
}
if stderr_open && (err_revents & unix.POLLHUP) != 0 && (err_revents & unix.POLLIN) == 0 {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rxc(%d) stderr", crp.id)
poll_fds[1].Fd = -1
poll_fds[1].Events = 0
stderr_open = false
}
if !stdout_open && !stderr_open { break }
}
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ending rxc(%d) loop", crp.id)
crp.ReqStop()
crp.stdin.Close() // close the input before waiting for program termination
err = crp.cmd.Wait()
if err != nil && crp.cmd.ProcessState == nil {
stop_msg = err.Error()
}
stop_flags = MakeRxcStopFlagsFromProcessState(crp.cmd.ProcessState)
err = cts.psc.Send(MakeRxcStopPacket(crp.id, stop_flags, stop_msg))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_WARN, "Failed to send %s from rxc(%d) to server - %s", PACKET_KIND_RXC_STOP.String(), crp.id, err.Error())
}
crp.stderr.Close()
crp.stdout.Close()
unix.Close(crp.pfd[0])
unix.Close(crp.pfd[1])
cts.rxc_mtx.Lock()
delete(cts.rxc_map, crp.id)
cts.rxc_mtx.Unlock()
cts.C.stats.rxc_sessions.Add(-1)
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ended rxc(%d) loop", crp.id)
}
func simple_command_escape_digit_value(ch rune, base int) int {
switch base {
case 8:
if ch >= '0' && ch <= '7' {
return int(ch - '0')
}
case 16:
switch {
case ch >= '0' && ch <= '9':
return int(ch - '0')
case ch >= 'a' && ch <= 'f':
return int(ch - 'a') + 10
case ch >= 'A' && ch <= 'F':
return int(ch - 'A') + 10
}
}
return -1
}
func parse_simple_command_escape_value(runes []rune, start int, base int, digit_count int, tag string) (rune, int, error) {
var ch rune
var i int
var value rune
var digit int
if start + digit_count > len(runes) {
return 0, 0, fmt.Errorf("truncated %s escape in simple command", tag)
}
value = 0
for i = 0; i < digit_count; i++ {
ch = runes[start + i]
digit = simple_command_escape_digit_value(ch, base)
if digit < 0 {
return 0, 0, fmt.Errorf("invalid %s escape in simple command", tag)
}
value = value * rune(base) + rune(digit)
}
if !utf8.ValidRune(value) {
return 0, 0, fmt.Errorf("invalid %s escape value in simple command", tag)
}
return value, digit_count, nil
}
func parse_simple_command_escape(runes []rune, start int, escape_level int) (rune, int, error) {
var ch rune
if start >= len(runes) {
return 0, 0, fmt.Errorf("missing escaped character in simple command")
}
ch = runes[start]
if escape_level <= 0 {
return ch, 0, nil
}
switch ch {
case 'a':
return '\a', 0, nil
case 'b':
return '\b', 0, nil
case 'f':
return '\f', 0, nil
case 'n':
return '\n', 0, nil
case 'r':
return '\r', 0, nil
case 't':
return '\t', 0, nil
case 'v':
return '\v', 0, nil
case '\\':
return '\\', 0, nil
case '"':
return '"', 0, nil
case '\'':
return '\'', 0, nil
}
if escape_level < 2 {
return ch, 0, nil
}
switch ch {
case 'x':
return parse_simple_command_escape_value(runes, start + 1, 16, 2, "\\x")
case 'o':
return parse_simple_command_escape_value(runes, start + 1, 8, 2, "\\o")
case 'u':
return parse_simple_command_escape_value(runes, start + 1, 16, 4, "\\u")
case 'U':
return parse_simple_command_escape_value(runes, start + 1, 16, 8, "\\U")
default:
return ch, 0, nil
}
}
func split_simple_command(script string, escape_level int) ([]string, error) {
var args []string
var runes []rune
var buf bytes.Buffer
var ch rune
var quote rune
var decoded rune
var i int
var skip_until int
var consumed int
var token_started bool
var escaped bool
var err error
args = make([]string, 0)
runes = []rune(script)
skip_until = -1
for i, ch = range runes {
if i < skip_until {
continue
}
if escaped {
if quote == '"' && escape_level > 0 {
decoded, consumed, err = parse_simple_command_escape(runes, i, escape_level)
if err != nil {
return nil, err
}
buf.WriteRune(decoded)
skip_until = i + consumed + 1
} else {
buf.WriteRune(ch)
skip_until = i + 1
}
token_started = true
escaped = false
continue
}
if ch == '\\' {
escaped = true
token_started = true
continue
}
if quote != 0 {
if ch == quote {
quote = 0
} else {
buf.WriteRune(ch)
token_started = true
}
continue
}
switch ch {
case '\'', '"':
quote = ch
token_started = true
default:
if unicode.IsSpace(ch) {
if token_started {
args = append(args, buf.String())
buf.Reset()
token_started = false
}
} else {
buf.WriteRune(ch)
token_started = true
}
}
}
if escaped { return nil, fmt.Errorf("dangling escape in simple command") }
if quote != 0 { return nil, fmt.Errorf("unterminated quote in simple command") }
if token_started { args = append(args, buf.String()) }
if len(args) <= 0 { return nil, fmt.Errorf("blank simple command") }
return args, nil
}
func expand_rxc_profile_arg(arg_expr string, input_args []string) (string, error) {
var buf bytes.Buffer
var r rune
var last_pos int
var pos int
var j int
var arg_idx_str string
var arg_idx int
var arg_expr_len int
var input_args_len int
var err error
input_args_len = len(input_args)
arg_expr_len = len(arg_expr)
last_pos = 0
for pos, r = range arg_expr { // use the for .. range expression for rune-based traversal
if r != '$' { continue }
buf.WriteString(arg_expr[last_pos:pos])
if pos + 1 >= arg_expr_len {
// nothing after $
buf.WriteByte('$')
last_pos = pos + 1
break
}
if arg_expr[pos + 1] == '$' {
// convert $$ to a literal $
buf.WriteByte('$')
last_pos = pos + 2
continue
}
if arg_expr[pos + 1] != '{' {
// $ not followed by {
buf.WriteByte('$')
last_pos = pos + 1
continue
}
j = pos + 2
for j < arg_expr_len && arg_expr[j] != '}' { j++ }
if j >= arg_expr_len {
return "", fmt.Errorf("unterminated rxc profile args expression %s", arg_expr)
}
arg_idx_str = arg_expr[pos + 2:j]
if arg_idx_str == "@" {
return "", fmt.Errorf("invalid use of ${@} in rxc profile args expression %s", arg_expr)
}
arg_idx, err = strconv.Atoi(arg_idx_str)
if err != nil || arg_idx <= 0 {
return "", fmt.Errorf("invalid rxc profile argument index ${%s}", arg_idx_str)
}
if arg_idx > input_args_len {
//return "", fmt.Errorf("rxc profile argument index ${%d} out of range", arg_idx)
// don't return failure. just carry on without writing anything
// write nothing to buf.
} else {
buf.WriteString(input_args[arg_idx - 1])
}
last_pos = j + 1
}
if last_pos < arg_expr_len { buf.WriteString(arg_expr[last_pos:]) }
return buf.String(), nil
}
func expand_rxc_profile_args(arg_exprs []string, input_args []string) ([]string, error) {
var args []string
var arg_expr string
args = make([]string, 0)
for _, arg_expr = range arg_exprs {
if arg_expr == "${@}" {
// ${@} must be used alone
args = append(args, input_args...)
} else {
var expanded string
var err error
// expand_rxc_profile_arg rejects ${@} as it doesn't allow
// to combine it with other elements (e.g. xx${@}yy is disallowed).
expanded, err = expand_rxc_profile_arg(arg_expr, input_args)
if err != nil { return nil, err }
args = append(args, expanded)
}
}
return args, nil
}
func apply_rxc_user(cmd *exec.Cmd, rxc_user string) error {
var uid int
var gid int
var u *user.User
var err error
u, err = user.Lookup(rxc_user)
if err != nil { return err }
uid, _ = strconv.Atoi(u.Uid)
gid, _ = strconv.Atoi(u.Gid)
cmd.SysProcAttr = &syscall.SysProcAttr{
Credential: &syscall.Credential{
Uid: uint32(uid),
Gid: uint32(gid),
},
Setsid: true,
}
cmd.Dir = u.HomeDir
cmd.Env = append(cmd.Env,
"HOME=" + u.HomeDir,
"LOGNAME=" + u.Username,
"PATH=" + os.Getenv("PATH"),
"USER=" + u.Username,
)
return nil
}
func connect_cmd(c *Client, type_ string, script string) (*exec.Cmd, *os.File, *os.File, *os.File, error) {
var cmd *exec.Cmd
var stdin io.WriteCloser
var stdout io.ReadCloser
var stderr io.ReadCloser
var stdin_f *os.File
var stdout_f *os.File
var stderr_f *os.File
var argv []string
var cmd_argv []string
var profile *ClientRxcProfile
var effective_user string
var ok bool
var err error
effective_user = c.rxc_user
if type_ == "simple" {
argv, err = split_simple_command(script, 0)
if err != nil {
return nil, nil, nil, nil, err
}
if argv[0] == "" {
return nil, nil, nil, nil, fmt.Errorf("blank simple command")
}
profile, err = c.ResolveRxcProfile(argv[0])
if err != nil {
return nil, nil, nil, nil, err
}
if profile == nil {
return nil, nil, nil, nil, fmt.Errorf("unknown rxc profile %s", argv[0])
}
if profile.User != "" { effective_user = profile.User }
if profile.Args != nil {
cmd_argv, err = expand_rxc_profile_args(profile.Args, argv[1:])
if err != nil {
return nil, nil, nil, nil, err
}
} else {
cmd_argv = argv[1:]
}
cmd = exec.Command(profile.Script, cmd_argv...)
} else {
return nil, nil, nil, nil, fmt.Errorf("unsupported type - %s", type_)
}
if effective_user != "" {
err = apply_rxc_user(cmd, effective_user)
if err != nil {
return nil, nil, nil, nil, err
}
}
stdout, err = cmd.StdoutPipe()
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("unable to get stdout pipe - %s", err.Error())
}
stdout_f, ok = stdout.(*os.File)
if !ok {
stdout.Close()
return nil, nil, nil, nil, fmt.Errorf("unsupported stdout pipe")
}
stderr, err = cmd.StderrPipe()
if err != nil {
stdout.Close()
return nil, nil, nil, nil, fmt.Errorf("unable to get stderr pipe - %s", err.Error())
}
stderr_f, ok = stderr.(*os.File)
if !ok {
stderr.Close()
stdout.Close()
return nil, nil, nil, nil, fmt.Errorf("unsupported stderr pipe")
}
stdin, err = cmd.StdinPipe()
if err != nil {
stderr.Close()
stdout.Close()
return nil, nil, nil, nil, fmt.Errorf("unable to get stdin pipe - %s", err.Error())
}
stdin_f, ok = stdin.(*os.File)
if !ok {
stdin.Close()
stderr.Close()
stdout.Close()
return nil, nil, nil, nil, fmt.Errorf("unsupported stdin pipe")
}
err = cmd.Start()
if err != nil {
stdin.Close()
stderr.Close()
stdout.Close()
return nil, nil, nil, nil, err
}
return cmd, stdin_f, stdout_f, stderr_f, nil
}
func (cts *ClientConn) StartRxc(id uint64, data []byte, wg *sync.WaitGroup) error {
var crp *ClientRxc
var ok bool
var i int
var dec *gob.Decoder
var args []string
var err error
var err2 error
dec = gob.NewDecoder(bytes.NewBuffer(data));
err = dec.Decode(&args);
if err != nil {
return fmt.Errorf("unable to decode data for start on rxc(%d) - %s", id, err.Error())
}
if len(args) != 2 {
return fmt.Errorf("invalid data for start on rxc(%d)", id)
}
cts.rxc_mtx.Lock()
_, ok = cts.rxc_map[id]
if ok {
cts.rxc_mtx.Unlock()
return fmt.Errorf("multiple start on rxc(%d)", id)
}
crp = &ClientRxc{ cts: cts, id: id, req_type: args[0], req_script: args[1] }
err = unix.Pipe(crp.pfd[:])
if err != nil {
cts.rxc_mtx.Unlock()
err2 = cts.psc.Send(MakeRxcStopPacket(id, 0, err.Error()))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send %s for rxc(%d) start failure to server - %s", PACKET_KIND_RXC_STOP.String(), id, err2.Error())
}
return fmt.Errorf("unable to create rxc(%d) event fd for %s(%s) - %s", id, args[0], args[1], err.Error())
}
crp.cmd, crp.stdin, crp.stdout, crp.stderr, err = connect_cmd(cts.C, args[0], args[1])
if err != nil {
cts.rxc_mtx.Unlock()
err2 = cts.psc.Send(MakeRxcStopPacket(id, 0, err.Error()))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send %s for rxc(%d) start failure to server - %s", PACKET_KIND_RXC_STOP.String(), id, err2.Error())
}
unix.Close(crp.pfd[0])
unix.Close(crp.pfd[1])
return fmt.Errorf("unable to start rxc(%d) for %s(%s) - %s", id, args[0], args[1], err.Error())
}
for i = 0; i < 2; i++ {
var flags int
flags, err = unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_GETFL, 0)
if err == nil {
unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_SETFL, flags | unix.O_NONBLOCK)
}
}
cts.rxc_map[id] = crp
wg.Add(1)
go cts.RxcLoop(crp, wg)
cts.rxc_mtx.Unlock()
return nil
}
func (cts *ClientConn) StopRxc(id uint64) error {
var crp *ClientRxc
crp = cts.FindClientRxcById(id)
if crp == nil {
return fmt.Errorf("unknown rxc id %d", id)
}
crp.ReqStop()
return nil
}
func (cts *ClientConn) WriteRxc(id uint64, data []byte) error {
var crp *ClientRxc
crp = cts.FindClientRxcById(id)
if crp == nil {
return fmt.Errorf("unknown rxc id %d", id)
}
// TODO: what if stdin can't be written fast enough and gets blocked?
crp.stdin.Write(data)
return nil
}
func (cts *ClientConn) HandleRxcEvent(packet_type PACKET_KIND, evt *RxcEvent) error {
switch packet_type {
case PACKET_KIND_RXC_START:
return cts.StartRxc(evt.Id, evt.Data, &cts.C.wg)
case PACKET_KIND_RXC_STOP:
return cts.StopRxc(evt.Id)
case PACKET_KIND_RXC_DATA:
return cts.WriteRxc(evt.Id, evt.Data)
}
// ignore other packet types
return nil
}