1661 lines
47 KiB
Go
1661 lines
47 KiB
Go
package hodu
|
|
|
|
import "context"
|
|
import "crypto/tls"
|
|
import "errors"
|
|
import "fmt"
|
|
import "io"
|
|
import "log"
|
|
import "net"
|
|
import "net/http"
|
|
import "net/netip"
|
|
import "strconv"
|
|
import "sync"
|
|
import "sync/atomic"
|
|
import "time"
|
|
import "unsafe"
|
|
|
|
import "golang.org/x/net/websocket"
|
|
import "google.golang.org/grpc"
|
|
import "google.golang.org/grpc/credentials"
|
|
//import "google.golang.org/grpc/metadata"
|
|
import "google.golang.org/grpc/peer"
|
|
import "google.golang.org/grpc/stats"
|
|
|
|
const PTS_LIMIT int = 16384
|
|
const CTS_LIMIT int = 16384
|
|
|
|
type PortId uint16
|
|
const PORT_ID_MARKER string = "_"
|
|
|
|
type ServerConnMapByAddr = map[net.Addr]*ServerConn
|
|
type ServerConnMap = map[ConnId]*ServerConn
|
|
type ServerRouteMap = map[RouteId]*ServerRoute
|
|
type ServerPeerConnMap = map[PeerId]*ServerPeerConn
|
|
type ServerSvcPortMap = map[PortId]ConnRouteId
|
|
|
|
type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader
|
|
type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error)
|
|
|
|
type Server struct {
|
|
ctx context.Context
|
|
ctx_cancel context.CancelFunc
|
|
pxytlscfg *tls.Config
|
|
wpxtlscfg *tls.Config
|
|
ctltlscfg *tls.Config
|
|
rpctlscfg *tls.Config
|
|
|
|
wg sync.WaitGroup
|
|
stop_req atomic.Bool
|
|
stop_chan chan bool
|
|
|
|
ext_mtx sync.Mutex
|
|
ext_svcs []Service
|
|
|
|
pxy_addr []string
|
|
pxy_ws *server_proxy_ssh_ws
|
|
pxy_mux *http.ServeMux
|
|
pxy []*http.Server // proxy server
|
|
|
|
wpx_addr []string
|
|
wpx_ws *server_proxy_ssh_ws
|
|
wpx_mux *http.ServeMux
|
|
wpx []*http.Server // proxy server than handles http/https only
|
|
|
|
ctl_addr []string
|
|
ctl_prefix string
|
|
ctl_mux *http.ServeMux
|
|
ctl []*http.Server // control server
|
|
|
|
rpc []*net.TCPListener // main listener for grpc
|
|
rpc_wg sync.WaitGroup
|
|
rpc_svr *grpc.Server
|
|
|
|
pts_limit int // global pts limit
|
|
cts_limit int
|
|
cts_next_id ConnId
|
|
cts_mtx sync.Mutex
|
|
cts_map ServerConnMap
|
|
cts_map_by_addr ServerConnMapByAddr
|
|
cts_wg sync.WaitGroup
|
|
|
|
log Logger
|
|
|
|
svc_port_mtx sync.Mutex
|
|
svc_port_map ServerSvcPortMap
|
|
|
|
stats struct {
|
|
conns atomic.Int64
|
|
routes atomic.Int64
|
|
peers atomic.Int64
|
|
ssh_proxy_sessions atomic.Int64
|
|
}
|
|
|
|
wpx_resp_tf ServerWpxResponseTransformer
|
|
wpx_foreign_port_proxy_maker ServerWpxForeignPortProxyMaker
|
|
xterm_html string
|
|
|
|
UnimplementedHoduServer
|
|
}
|
|
|
|
// connection from client.
|
|
// client connect to the server, the server accept it, and makes a tunnel request
|
|
type ServerConn struct {
|
|
svr *Server
|
|
Id ConnId
|
|
sid string // for logging
|
|
|
|
RemoteAddr net.Addr // client address that created this structure
|
|
LocalAddr net.Addr // local address that the client is connected to
|
|
pss *GuardedPacketStreamServer
|
|
|
|
route_mtx sync.Mutex
|
|
route_map ServerRouteMap
|
|
route_wg sync.WaitGroup
|
|
|
|
wg sync.WaitGroup
|
|
stop_req atomic.Bool
|
|
stop_chan chan bool
|
|
}
|
|
|
|
type ServerRoute struct {
|
|
Cts *ServerConn
|
|
Id RouteId
|
|
|
|
svc_l *net.TCPListener
|
|
SvcAddr *net.TCPAddr // actual listening address
|
|
SvcReqAddr string
|
|
SvcPermNet netip.Prefix // network from which access is allowed
|
|
SvcOption RouteOption
|
|
|
|
PtcAddr string
|
|
PtcName string
|
|
|
|
pts_mtx sync.Mutex
|
|
pts_map ServerPeerConnMap
|
|
pts_limit int
|
|
pts_next_id PeerId
|
|
pts_wg sync.WaitGroup
|
|
stop_req atomic.Bool
|
|
}
|
|
|
|
type ServerPluginInterface interface {
|
|
ModifyResponse(w http.ResponseWriter, r *http.Request)
|
|
Init(server *Server)
|
|
Cleanup()
|
|
}
|
|
|
|
type GuardedPacketStreamServer struct {
|
|
mtx sync.Mutex
|
|
//pss Hodu_PacketStreamServer
|
|
Hodu_PacketStreamServer // let's embed it to avoid reimplement Recv() and Context()
|
|
}
|
|
|
|
// ------------------------------------
|
|
|
|
func (g *GuardedPacketStreamServer) Send(data *Packet) error {
|
|
// while Recv() on a stream is called from the same gorountine all the time,
|
|
// Send() is called from multiple places. let's guard it as grpc-go
|
|
// doesn't provide concurrency safety in this case.
|
|
// https://github.com/grpc/grpc-go/blob/master/Documentation/concurrency.md
|
|
g.mtx.Lock()
|
|
defer g.mtx.Unlock()
|
|
return g.Hodu_PacketStreamServer.Send(data)
|
|
}
|
|
|
|
/*
|
|
func (g *GuardedPacketStreamServer) Recv() (*Packet, error) {
|
|
return g.pss.Recv()
|
|
}
|
|
|
|
func (g *GuardedPacketStreamServer) Context() context.Context {
|
|
return g.pss.Context()
|
|
}*/
|
|
|
|
// ------------------------------------
|
|
|
|
func NewServerRoute(cts *ServerConn, id RouteId, option RouteOption, ptc_addr string, ptc_name string, svc_requested_addr string, svc_permitted_net string) (*ServerRoute, error) {
|
|
var r ServerRoute
|
|
var l *net.TCPListener
|
|
var svcaddr *net.TCPAddr
|
|
var svcnet netip.Prefix
|
|
var err error
|
|
|
|
if svc_permitted_net != "" {
|
|
// parse the permitted network before creating a listener.
|
|
// the listener opened doesn't have to be closed when parsing fails.
|
|
svcnet, err = netip.ParsePrefix(svc_permitted_net)
|
|
if err != nil {
|
|
return nil , err
|
|
}
|
|
}
|
|
|
|
l, svcaddr, err = cts.make_route_listener(id, option, svc_requested_addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if svc_permitted_net == "" {
|
|
if svcaddr.IP.To4() != nil {
|
|
svcnet = IPV4_PREFIX_ZERO
|
|
} else {
|
|
svcnet = IPV6_PREFIX_ZERO
|
|
}
|
|
}
|
|
|
|
r.Cts = cts
|
|
r.Id = id
|
|
r.svc_l = l
|
|
r.SvcAddr = svcaddr
|
|
r.SvcReqAddr = svc_requested_addr
|
|
r.SvcPermNet = svcnet
|
|
r.SvcOption = option
|
|
|
|
r.PtcAddr = ptc_addr
|
|
r.PtcName = ptc_name
|
|
r.pts_limit = PTS_LIMIT
|
|
r.pts_map = make(ServerPeerConnMap)
|
|
r.pts_next_id = 1
|
|
r.stop_req.Store(false)
|
|
|
|
return &r, nil
|
|
}
|
|
|
|
func (r *ServerRoute) AddNewServerPeerConn(c *net.TCPConn) (*ServerPeerConn, error) {
|
|
var pts *ServerPeerConn
|
|
var ok bool
|
|
var start_id PeerId
|
|
var assigned_id PeerId
|
|
|
|
r.pts_mtx.Lock()
|
|
defer r.pts_mtx.Unlock()
|
|
|
|
if len(r.pts_map) >= r.pts_limit {
|
|
return nil, fmt.Errorf("peer-to-server connection table full")
|
|
}
|
|
|
|
start_id = r.pts_next_id
|
|
for {
|
|
_, ok = r.pts_map[r.pts_next_id]
|
|
if !ok {
|
|
assigned_id = r.pts_next_id
|
|
r.pts_next_id++
|
|
if r.pts_next_id == 0 { r.pts_next_id++ }
|
|
break
|
|
}
|
|
r.pts_next_id++
|
|
if r.pts_next_id == 0 { r.pts_next_id++ }
|
|
if r.pts_next_id == start_id {
|
|
// unlikely to happen but it cycled through the whole range.
|
|
return nil, fmt.Errorf("failed to assign peer-to-server connection id")
|
|
}
|
|
}
|
|
|
|
pts = NewServerPeerConn(r, c, assigned_id)
|
|
r.pts_map[pts.conn_id] = pts
|
|
r.Cts.svr.stats.peers.Add(1)
|
|
|
|
return pts, nil
|
|
}
|
|
|
|
func (r *ServerRoute) RemoveServerPeerConn(pts *ServerPeerConn) {
|
|
r.pts_mtx.Lock()
|
|
delete(r.pts_map, pts.conn_id)
|
|
r.Cts.svr.stats.peers.Add(-1)
|
|
r.pts_mtx.Unlock()
|
|
r.Cts.svr.log.Write(r.Cts.sid, LOG_DEBUG, "Removed server-side peer connection %s from route(%d)", pts.conn.RemoteAddr().String(), r.Id)
|
|
}
|
|
|
|
func (r *ServerRoute) RunTask(wg *sync.WaitGroup) {
|
|
var err error
|
|
var conn *net.TCPConn
|
|
var pts *ServerPeerConn
|
|
var raddr *net.TCPAddr
|
|
var iaddr netip.Addr
|
|
|
|
defer wg.Done()
|
|
|
|
for {
|
|
conn, err = r.svc_l.AcceptTCP() // this call is blocking...
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
r.Cts.svr.log.Write(r.Cts.sid, LOG_INFO, "Server-side peer listener closed on route(%d)", r.Id)
|
|
} else {
|
|
r.Cts.svr.log.Write(r.Cts.sid, LOG_INFO, "Server-side peer listener error on route(%d) - %s", r.Id, err.Error())
|
|
}
|
|
break
|
|
}
|
|
|
|
raddr = conn.RemoteAddr().(*net.TCPAddr)
|
|
iaddr, _ = netip.AddrFromSlice(raddr.IP)
|
|
|
|
if !r.SvcPermNet.Contains(iaddr) {
|
|
r.Cts.svr.log.Write(r.Cts.sid, LOG_DEBUG, "Rejected server-side peer %s to route(%d) - allowed range %v", raddr.String(), r.Id, r.SvcPermNet)
|
|
conn.Close()
|
|
}
|
|
|
|
if r.Cts.svr.pts_limit > 0 && int(r.Cts.svr.stats.peers.Load()) >= r.Cts.svr.pts_limit {
|
|
r.Cts.svr.log.Write(r.Cts.sid, LOG_DEBUG, "Rejected server-side peer %s to route(%d) - allowed max %d", raddr.String(), r.Id, r.Cts.svr.pts_limit)
|
|
conn.Close()
|
|
}
|
|
|
|
pts, err = r.AddNewServerPeerConn(conn)
|
|
if err != nil {
|
|
r.Cts.svr.log.Write(r.Cts.sid, LOG_ERROR, "Failed to add server-side peer %s to route(%d) - %s", r.Id, raddr.String(), r.Id, err.Error())
|
|
conn.Close()
|
|
} else {
|
|
r.Cts.svr.log.Write(r.Cts.sid, LOG_DEBUG, "Added server-side peer %s to route(%d)", raddr.String(), r.Id)
|
|
r.pts_wg.Add(1)
|
|
go pts.RunTask(&r.pts_wg)
|
|
}
|
|
}
|
|
|
|
r.ReqStop()
|
|
|
|
r.pts_wg.Wait()
|
|
r.Cts.svr.log.Write(r.Cts.sid, LOG_DEBUG, "All service-side peer handlers ended on route(%d)", r.Id)
|
|
|
|
r.Cts.RemoveServerRoute(r) // final phase...
|
|
}
|
|
|
|
func (r *ServerRoute) ReqStop() {
|
|
if r.stop_req.CompareAndSwap(false, true) {
|
|
var pts *ServerPeerConn
|
|
|
|
r.pts_mtx.Lock()
|
|
for _, pts = range r.pts_map { pts.ReqStop() }
|
|
r.pts_mtx.Unlock()
|
|
|
|
r.svc_l.Close()
|
|
}
|
|
}
|
|
|
|
func (r *ServerRoute) ReportEvent(pts_id PeerId, event_type PACKET_KIND, event_data interface{}) error {
|
|
var spc *ServerPeerConn
|
|
var ok bool
|
|
|
|
r.pts_mtx.Lock()
|
|
spc, ok = r.pts_map[pts_id]
|
|
if !ok {
|
|
r.pts_mtx.Unlock()
|
|
return fmt.Errorf("non-existent peer id - %d", pts_id)
|
|
}
|
|
r.pts_mtx.Unlock()
|
|
|
|
return spc.ReportEvent(event_type, event_data)
|
|
}
|
|
// ------------------------------------
|
|
|
|
func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_requested_addr string) (*net.TCPListener, *net.TCPAddr, error) {
|
|
var l *net.TCPListener
|
|
var svcaddr *net.TCPAddr
|
|
var nw string
|
|
var prev_cri ConnRouteId
|
|
var ok bool
|
|
var err error
|
|
|
|
if svc_requested_addr != "" {
|
|
var ap netip.AddrPort
|
|
|
|
ap, err = netip.ParseAddrPort(svc_requested_addr)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("invalid service address %s - %s", svc_requested_addr, err.Error())
|
|
}
|
|
|
|
svcaddr = &net.TCPAddr{IP: ap.Addr().AsSlice(), Port: int(ap.Port())}
|
|
}
|
|
|
|
if option & RouteOption(ROUTE_OPTION_TCP) != 0 {
|
|
nw = "tcp"
|
|
if svcaddr == nil {
|
|
svcaddr = &net.TCPAddr{Port: 0} // port 0 for automatic assignment.
|
|
}
|
|
} else if option & RouteOption(ROUTE_OPTION_TCP4) != 0 {
|
|
nw = "tcp4"
|
|
if svcaddr == nil {
|
|
svcaddr = &net.TCPAddr{IP: net.IPv4zero, Port: 0} // port 0 for automatic assignment.
|
|
}
|
|
} else if option & RouteOption(ROUTE_OPTION_TCP6) != 0 {
|
|
nw = "tcp6"
|
|
if svcaddr == nil {
|
|
svcaddr = &net.TCPAddr{IP: net.IPv6zero, Port: 0} // port 0 for automatic assignment.
|
|
}
|
|
} else {
|
|
return nil, nil, fmt.Errorf("invalid route option value %d(%s)", option, option.string())
|
|
}
|
|
|
|
l, err = net.ListenTCP(nw, svcaddr) // make the binding address configurable. support multiple binding addresses???
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
svcaddr = l.Addr().(*net.TCPAddr)
|
|
|
|
cts.svr.svc_port_mtx.Lock()
|
|
prev_cri, ok = cts.svr.svc_port_map[PortId(svcaddr.Port)]
|
|
if ok {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Route(%d,%d) on %s not unique by port number - existing route(%d,%d)",
|
|
cts.Id, id, prev_cri.conn_id, prev_cri.route_id, svcaddr.String())
|
|
l.Close()
|
|
return nil, nil, err
|
|
}
|
|
cts.svr.svc_port_map[PortId(svcaddr.Port)] = ConnRouteId{conn_id: cts.Id, route_id: id}
|
|
cts.svr.svc_port_mtx.Unlock()
|
|
|
|
cts.svr.log.Write(cts.sid, LOG_DEBUG, "Route(%d,%d) listening on %s", cts.Id, id, svcaddr.String())
|
|
return l, svcaddr, nil
|
|
}
|
|
|
|
func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, ptc_addr string, ptc_name string, svc_requested_addr string, svc_permitted_net string) (*ServerRoute, error) {
|
|
var r *ServerRoute
|
|
var err error
|
|
|
|
cts.route_mtx.Lock()
|
|
if cts.route_map[route_id] != nil {
|
|
// If this happens, something must be wrong between the server and the client
|
|
// most likely, it must be a logic error. the state must not go out of sync
|
|
// as the route_id and the peer_id are supposed to be the same between the client
|
|
// and the server.
|
|
cts.route_mtx.Unlock()
|
|
return nil, fmt.Errorf("existent route id - %d", route_id)
|
|
}
|
|
r, err = NewServerRoute(cts, route_id, proto, ptc_addr, ptc_name, svc_requested_addr, svc_permitted_net)
|
|
if err != nil {
|
|
cts.route_mtx.Unlock()
|
|
return nil, err
|
|
}
|
|
cts.route_map[route_id] = r
|
|
cts.svr.stats.routes.Add(1)
|
|
cts.route_mtx.Unlock()
|
|
|
|
cts.route_wg.Add(1)
|
|
go r.RunTask(&cts.route_wg)
|
|
return r, nil
|
|
}
|
|
|
|
func (cts *ServerConn) RemoveServerRoute(route *ServerRoute) error {
|
|
var r *ServerRoute
|
|
var ok bool
|
|
|
|
cts.route_mtx.Lock()
|
|
r, ok = cts.route_map[route.Id]
|
|
if !ok {
|
|
cts.route_mtx.Unlock()
|
|
return fmt.Errorf("non-existent route id - %d", route.Id)
|
|
}
|
|
if r != route {
|
|
cts.route_mtx.Unlock()
|
|
return fmt.Errorf("non-existent route - %d", route.Id)
|
|
}
|
|
delete(cts.route_map, route.Id)
|
|
cts.svr.stats.routes.Add(-1)
|
|
cts.route_mtx.Unlock()
|
|
|
|
cts.svr.svc_port_mtx.Lock()
|
|
delete(cts.svr.svc_port_map, PortId(route.SvcAddr.Port))
|
|
cts.svr.svc_port_mtx.Unlock()
|
|
|
|
r.ReqStop()
|
|
return nil
|
|
}
|
|
|
|
func (cts *ServerConn) RemoveServerRouteById(route_id RouteId) (*ServerRoute, error) {
|
|
var r *ServerRoute
|
|
var ok bool
|
|
|
|
cts.route_mtx.Lock()
|
|
r, ok = cts.route_map[route_id]
|
|
if !ok {
|
|
cts.route_mtx.Unlock()
|
|
return nil, fmt.Errorf("non-existent route id - %d", route_id)
|
|
}
|
|
delete(cts.route_map, route_id)
|
|
cts.svr.stats.routes.Add(-1)
|
|
cts.route_mtx.Unlock()
|
|
|
|
cts.svr.svc_port_mtx.Lock()
|
|
delete(cts.svr.svc_port_map, PortId(r.SvcAddr.Port))
|
|
cts.svr.svc_port_mtx.Unlock()
|
|
|
|
r.ReqStop()
|
|
return r, nil
|
|
}
|
|
|
|
func (cts *ServerConn) FindServerRouteById(route_id RouteId) *ServerRoute {
|
|
var r *ServerRoute
|
|
var ok bool
|
|
|
|
cts.route_mtx.Lock()
|
|
r, ok = cts.route_map[route_id]
|
|
if !ok {
|
|
cts.route_mtx.Unlock()
|
|
return nil
|
|
}
|
|
cts.route_mtx.Unlock()
|
|
|
|
return r
|
|
}
|
|
|
|
func (cts *ServerConn) ReqStopAllServerRoutes() {
|
|
var r *ServerRoute
|
|
|
|
cts.route_mtx.Lock()
|
|
for _, r = range cts.route_map { r.ReqStop() }
|
|
cts.route_mtx.Unlock()
|
|
}
|
|
|
|
func (cts *ServerConn) ReportEvent(route_id RouteId, pts_id PeerId, event_type PACKET_KIND, event_data interface{}) error {
|
|
var r *ServerRoute
|
|
var ok bool
|
|
|
|
cts.route_mtx.Lock()
|
|
r, ok = cts.route_map[route_id]
|
|
if !ok {
|
|
cts.route_mtx.Unlock()
|
|
return fmt.Errorf("non-existent route id - %d", route_id)
|
|
}
|
|
cts.route_mtx.Unlock()
|
|
|
|
return r.ReportEvent(pts_id, event_type, event_data)
|
|
}
|
|
|
|
func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
|
|
var pkt *Packet
|
|
var err error
|
|
|
|
defer wg.Done()
|
|
|
|
for {
|
|
pkt, err = cts.pss.Recv()
|
|
if errors.Is(err, io.EOF) {
|
|
cts.svr.log.Write(cts.sid, LOG_INFO, "RPC stream closed for client %s", cts.RemoteAddr)
|
|
goto done
|
|
}
|
|
if err != nil {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR, "RPC stream error for client %s - %s", cts.RemoteAddr, err.Error())
|
|
goto done
|
|
}
|
|
|
|
switch pkt.Kind {
|
|
case PACKET_KIND_ROUTE_START:
|
|
var x *Packet_Route
|
|
var ok bool
|
|
x, ok = pkt.U.(*Packet_Route)
|
|
if ok {
|
|
var r *ServerRoute
|
|
|
|
r, err = cts.AddNewServerRoute(RouteId(x.Route.RouteId), RouteOption(x.Route.ServiceOption), x.Route.TargetAddrStr, x.Route.TargetName, x.Route.ServiceAddrStr, x.Route.ServiceNetStr)
|
|
if err != nil {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Failed to add route(%d,%s) for %s - %s",
|
|
x.Route.RouteId, x.Route.TargetAddrStr, cts.RemoteAddr, err.Error())
|
|
|
|
err = cts.pss.Send(MakeRouteStoppedPacket(RouteId(x.Route.RouteId), RouteOption(x.Route.ServiceOption), x.Route.TargetAddrStr, x.Route.TargetName, x.Route.ServiceAddrStr, x.Route.ServiceNetStr))
|
|
if err != nil {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Failed to send route_stopped event(%d,%s,%v,%s) to client %s - %s",
|
|
x.Route.RouteId, x.Route.TargetAddrStr, x.Route.ServiceOption, x.Route.ServiceNetStr, cts.RemoteAddr, err.Error())
|
|
goto done
|
|
} else {
|
|
cts.svr.log.Write(cts.sid, LOG_DEBUG,
|
|
"Sent route_stopped event(%d,%s,%v,%s) to client %s",
|
|
x.Route.RouteId, x.Route.TargetAddrStr, x.Route.ServiceOption, x.Route.ServiceNetStr, cts.RemoteAddr)
|
|
}
|
|
|
|
} else {
|
|
cts.svr.log.Write(cts.sid, LOG_INFO,
|
|
"Added route(%d,%s,%s,%v,%v) for client %s to cts(%d)",
|
|
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, cts.RemoteAddr, cts.Id)
|
|
err = cts.pss.Send(MakeRouteStartedPacket(r.Id, r.SvcOption, r.SvcAddr.String(), r.PtcName, r.SvcReqAddr, r.SvcPermNet.String()))
|
|
if err != nil {
|
|
r.ReqStop()
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Failed to send route_started event(%d,%s,%s,%s%v,%v) to client %s - %s",
|
|
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, cts.RemoteAddr, err.Error())
|
|
goto done
|
|
}
|
|
}
|
|
} else {
|
|
cts.svr.log.Write(cts.sid, LOG_INFO, "Received invalid packet from %s", cts.RemoteAddr)
|
|
// TODO: need to abort this client?
|
|
}
|
|
|
|
case PACKET_KIND_ROUTE_STOP:
|
|
var x *Packet_Route
|
|
var ok bool
|
|
x, ok = pkt.U.(*Packet_Route)
|
|
if ok {
|
|
var r *ServerRoute
|
|
|
|
r, err = cts.RemoveServerRouteById(RouteId(x.Route.RouteId))
|
|
if err != nil {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Failed to delete route(%d,%s) for client %s - %s",
|
|
x.Route.RouteId, x.Route.TargetAddrStr, cts.RemoteAddr, err.Error())
|
|
} else {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Deleted route(%d,%s,%s,%v,%v) for client %s",
|
|
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet.String(), cts.RemoteAddr)
|
|
err = cts.pss.Send(MakeRouteStoppedPacket(r.Id, r.SvcOption, r.PtcAddr, r.PtcName, r.SvcReqAddr, r.SvcPermNet.String()))
|
|
if err != nil {
|
|
r.ReqStop()
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Failed to send route_stopped event(%d,%s,%s,%v.%v) to client %s - %s",
|
|
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet.String(), cts.RemoteAddr, err.Error())
|
|
goto done
|
|
}
|
|
}
|
|
} else {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR, "Invalid route_stop event from %s", cts.RemoteAddr)
|
|
}
|
|
|
|
case PACKET_KIND_PEER_STARTED:
|
|
// the connection from the client to a peer has been established
|
|
var x *Packet_Peer
|
|
var ok bool
|
|
x, ok = pkt.U.(*Packet_Peer)
|
|
if ok {
|
|
err = cts.ReportEvent(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), PACKET_KIND_PEER_STARTED, x.Peer)
|
|
if err != nil {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Failed to handle peer_started event from %s for peer(%d,%d,%s,%s) - %s",
|
|
cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error())
|
|
} else {
|
|
cts.svr.log.Write(cts.sid, LOG_DEBUG,
|
|
"Handled peer_started event from %s for peer(%d,%d,%s,%s)",
|
|
cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr)
|
|
}
|
|
} else {
|
|
// invalid event data
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR, "Invalid peer_started event from %s", cts.RemoteAddr)
|
|
}
|
|
|
|
case PACKET_KIND_PEER_ABORTED:
|
|
var x *Packet_Peer
|
|
var ok bool
|
|
x, ok = pkt.U.(*Packet_Peer)
|
|
if ok {
|
|
err = cts.ReportEvent(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), PACKET_KIND_PEER_ABORTED, x.Peer)
|
|
if err != nil {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Failed to handle peer_aborted event from %s for peer(%d,%d,%s,%s) - %s",
|
|
cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error())
|
|
} else {
|
|
cts.svr.log.Write(cts.sid, LOG_DEBUG,
|
|
"Handled peer_aborted event from %s for peer(%d,%d,%s,%s)",
|
|
cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr)
|
|
}
|
|
} else {
|
|
// invalid event data
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR, "Invalid peer_aborted event from %s", cts.RemoteAddr)
|
|
}
|
|
|
|
case PACKET_KIND_PEER_STOPPED:
|
|
// the connection from the client to a peer has been established
|
|
var x *Packet_Peer
|
|
var ok bool
|
|
x, ok = pkt.U.(*Packet_Peer)
|
|
if ok {
|
|
err = cts.ReportEvent(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), PACKET_KIND_PEER_STOPPED, x.Peer)
|
|
if err != nil {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Failed to handle peer_stopped event from %s for peer(%d,%d,%s,%s) - %s",
|
|
cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error())
|
|
} else {
|
|
cts.svr.log.Write(cts.sid, LOG_DEBUG,
|
|
"Handled peer_stopped event from %s for peer(%d,%d,%s,%s)",
|
|
cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr)
|
|
}
|
|
} else {
|
|
// invalid event data
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR, "Invalid peer_stopped event from %s", cts.RemoteAddr)
|
|
}
|
|
|
|
case PACKET_KIND_PEER_DATA:
|
|
// the connection from the client to a peer has been established
|
|
var x *Packet_Data
|
|
var ok bool
|
|
x, ok = pkt.U.(*Packet_Data)
|
|
if ok {
|
|
err = cts.ReportEvent(RouteId(x.Data.RouteId), PeerId(x.Data.PeerId), PACKET_KIND_PEER_DATA, x.Data.Data)
|
|
if err != nil {
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR,
|
|
"Failed to handle peer_data event from %s for peer(%d,%d) - %s",
|
|
cts.RemoteAddr, x.Data.RouteId, x.Data.PeerId, err.Error())
|
|
} else {
|
|
cts.svr.log.Write(cts.sid, LOG_DEBUG,
|
|
"Handled peer_data event from %s for peer(%d,%d)",
|
|
cts.RemoteAddr, x.Data.RouteId, x.Data.PeerId)
|
|
}
|
|
} else {
|
|
// invalid event data
|
|
cts.svr.log.Write(cts.sid, LOG_ERROR, "Invalid peer_data event from %s", cts.RemoteAddr)
|
|
}
|
|
}
|
|
}
|
|
|
|
done:
|
|
cts.svr.log.Write(cts.sid, LOG_INFO, "RPC stream receiver ended")
|
|
}
|
|
|
|
func (cts *ServerConn) RunTask(wg *sync.WaitGroup) {
|
|
var strm *GuardedPacketStreamServer
|
|
var ctx context.Context
|
|
|
|
defer wg.Done()
|
|
|
|
strm = cts.pss
|
|
ctx = strm.Context()
|
|
|
|
// it looks like the only proper way to interrupt the blocking Recv
|
|
// call on the grpc streaming server is exit from the service handler
|
|
// which is this function invoked from PacketStream().
|
|
// there is no cancel function or whatever that can interrupt it.
|
|
// so start the Recv() loop in a separte goroutine and let this
|
|
// function be the channel waiter only.
|
|
// increment on the wait group is for the caller to wait for
|
|
// these detached goroutines to finish.
|
|
wg.Add(1)
|
|
go cts.receive_from_stream(wg)
|
|
|
|
for {
|
|
// exit if context is done
|
|
// or continue
|
|
select {
|
|
case <-ctx.Done(): // the stream context is done
|
|
cts.svr.log.Write(cts.sid, LOG_INFO, "RPC stream done - %s", ctx.Err().Error())
|
|
goto done
|
|
|
|
case <- cts.stop_chan:
|
|
// get out of the loop to eventually to exit from
|
|
// this handler to let the main grpc server to
|
|
// close this specific client connection.
|
|
goto done
|
|
|
|
//default:
|
|
// no other case is ready.
|
|
// without the default case, the select construct would block
|
|
}
|
|
}
|
|
|
|
done:
|
|
cts.ReqStop() // just in case
|
|
cts.route_wg.Wait()
|
|
}
|
|
|
|
func (cts *ServerConn) ReqStop() {
|
|
if cts.stop_req.CompareAndSwap(false, true) {
|
|
var r *ServerRoute
|
|
|
|
cts.route_mtx.Lock()
|
|
for _, r = range cts.route_map { r.ReqStop() }
|
|
cts.route_mtx.Unlock()
|
|
|
|
// there is no good way to break a specific connection client to
|
|
// the grpc server. while the global grpc server is closed in
|
|
// ReqStop() for Server, the individuation connection is closed
|
|
// by returing from the grpc handler goroutine. See the comment
|
|
// RunTask() for ServerConn.
|
|
cts.stop_chan <- true
|
|
}
|
|
}
|
|
|
|
// --------------------------------------------------------------------
|
|
|
|
func (s *Server) GetSeed(ctx context.Context, c_seed *Seed) (*Seed, error) {
|
|
var s_seed Seed
|
|
|
|
// seed exchange is for furture expansion of the protocol
|
|
// there is nothing to do much about it for now.
|
|
|
|
s_seed.Version = HODU_RPC_VERSION
|
|
s_seed.Flags = 0
|
|
|
|
// we create no ServerConn structure associated with the connection
|
|
// at this phase for the server. it doesn't track the client version and
|
|
// features. we delegate protocol selection solely to the client.
|
|
|
|
return &s_seed, nil
|
|
}
|
|
|
|
func (s *Server) PacketStream(strm Hodu_PacketStreamServer) error {
|
|
var ctx context.Context
|
|
var p *peer.Peer
|
|
var ok bool
|
|
var err error
|
|
var cts *ServerConn
|
|
|
|
ctx = strm.Context()
|
|
p, ok = peer.FromContext(ctx)
|
|
if !ok {
|
|
return fmt.Errorf("failed to get peer from packet stream context")
|
|
}
|
|
|
|
cts, err = s.AddNewServerConn(&p.Addr, &p.LocalAddr, strm)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to add client %s - %s", p.Addr.String(), err.Error())
|
|
}
|
|
|
|
// Don't detached the cts task as a go-routine as this function
|
|
// is invoked as a go-routine by the grpc server.
|
|
s.cts_wg.Add(1)
|
|
cts.RunTask(&s.cts_wg)
|
|
return nil
|
|
}
|
|
|
|
// ------------------------------------
|
|
|
|
type ConnCatcher struct {
|
|
server *Server
|
|
}
|
|
|
|
func (cc *ConnCatcher) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
|
|
return ctx
|
|
}
|
|
|
|
func (cc *ConnCatcher) HandleRPC(ctx context.Context, s stats.RPCStats) {
|
|
}
|
|
|
|
func (cc *ConnCatcher) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
|
|
return ctx
|
|
//return context.TODO()
|
|
}
|
|
|
|
func (cc *ConnCatcher) HandleConn(ctx context.Context, cs stats.ConnStats) {
|
|
var p *peer.Peer
|
|
var ok bool
|
|
var addr string
|
|
|
|
p, ok = peer.FromContext(ctx)
|
|
if !ok {
|
|
addr = ""
|
|
} else {
|
|
addr = p.Addr.String()
|
|
}
|
|
|
|
/*
|
|
md,ok:=metadata.FromIncomingContext(ctx)
|
|
if ok {
|
|
}*/
|
|
|
|
switch cs.(type) {
|
|
case *stats.ConnBegin:
|
|
cc.server.log.Write("", LOG_INFO, "Client connected - %s", addr)
|
|
|
|
case *stats.ConnEnd:
|
|
var cts *ServerConn
|
|
var log_id string
|
|
cts, _ = cc.server.RemoveServerConnByAddr(p.Addr)
|
|
if cts != nil { log_id = cts.sid }
|
|
cc.server.log.Write(log_id, LOG_INFO, "Client disconnected - %s", addr)
|
|
}
|
|
}
|
|
|
|
// ------------------------------------
|
|
|
|
type wrappedStream struct {
|
|
grpc.ServerStream
|
|
}
|
|
|
|
func (w *wrappedStream) RecvMsg(msg interface{}) error {
|
|
return w.ServerStream.RecvMsg(msg)
|
|
}
|
|
|
|
func (w *wrappedStream) SendMsg(msg interface{}) error {
|
|
return w.ServerStream.SendMsg(msg)
|
|
}
|
|
|
|
func newWrappedStream(s grpc.ServerStream) grpc.ServerStream {
|
|
return &wrappedStream{s}
|
|
}
|
|
|
|
func streamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
|
var err error
|
|
|
|
// authentication (token verification)
|
|
/*
|
|
md, ok = metadata.FromIncomingContext(ss.Context())
|
|
if !ok {
|
|
return errMissingMetadata
|
|
}
|
|
if !valid(md["authorization"]) {
|
|
return errInvalidToken
|
|
}
|
|
*/
|
|
|
|
err = handler(srv, newWrappedStream(ss))
|
|
if err != nil {
|
|
// TODO: LOGGING
|
|
}
|
|
return err
|
|
}
|
|
|
|
func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
|
var v interface{}
|
|
var err error
|
|
|
|
// authentication (token verification)
|
|
/*
|
|
md, ok = metadata.FromIncomingContext(ctx)
|
|
if !ok {
|
|
return nil, errMissingMetadata
|
|
}
|
|
if !valid(md["authorization"]) {
|
|
// return nil, errInvalidToken
|
|
}
|
|
*/
|
|
|
|
v, err = handler(ctx, req)
|
|
if err != nil {
|
|
//fmt.Printf("RPC failed with error: %v\n", err)
|
|
// TODO: Logging?
|
|
}
|
|
|
|
return v, err
|
|
}
|
|
|
|
|
|
type server_http_log_writer struct {
|
|
svr *Server
|
|
}
|
|
|
|
func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) {
|
|
// the standard http.Server always requires *log.Logger
|
|
// use this iowriter to create a logger to pass it to the http server.
|
|
// since this is another log write wrapper, give adjustment value
|
|
hlw.svr.log.WriteWithCallDepth("", LOG_INFO, +1, string(p))
|
|
return len(p), nil
|
|
}
|
|
|
|
// ServerHttpHandler is the handler contract accepted by
// Server.wrap_http_handler.
type ServerHttpHandler interface {
	// ServeHTTP serves the request and reports the status code along with
	// an optional error. wrap_http_handler logs the request only when the
	// status code is greater than 0.
	ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error)

	// GetId returns the id that wrap_http_handler passes to the logger.
	GetId() string
}
|
|
|
|
func (s *Server) wrap_http_handler(handler ServerHttpHandler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
var status_code int
|
|
var err error
|
|
var start_time time.Time
|
|
var time_taken time.Duration
|
|
|
|
// this deferred function is to overcome the recovering implemenation
|
|
// from panic done in go's http server. in that implemenation, panic
|
|
// is isolated to a single gorountine. however, i want this program
|
|
// to exit immediately once a panic condition is caught. (e.g. nil
|
|
// pointer dererence)
|
|
defer func() {
|
|
var err interface{} = recover()
|
|
if err != nil { dump_call_frame_and_exit(s.log, req, err) }
|
|
}()
|
|
|
|
start_time = time.Now()
|
|
status_code, err = handler.ServeHTTP(w, req)
|
|
time_taken = time.Now().Sub(start_time)
|
|
|
|
if status_code > 0 {
|
|
if err == nil {
|
|
s.log.Write(handler.GetId(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds())
|
|
} else {
|
|
s.log.Write(handler.GetId(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds(), err.Error())
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx_addrs []string, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config, pxytlscfg *tls.Config, wpxtlscfg *tls.Config, rpc_max int, peer_max int) (*Server, error) {
|
|
var s Server
|
|
var l *net.TCPListener
|
|
var rpcaddr *net.TCPAddr
|
|
var addr string
|
|
var gl *net.TCPListener
|
|
var i int
|
|
var hs_log *log.Logger
|
|
var opts []grpc.ServerOption
|
|
var err error
|
|
|
|
if len(rpc_addrs) <= 0 {
|
|
return nil, fmt.Errorf("no server addresses provided")
|
|
}
|
|
|
|
s.ctx, s.ctx_cancel = context.WithCancel(ctx)
|
|
s.log = logger
|
|
/* create the specified number of listeners */
|
|
s.rpc = make([]*net.TCPListener, 0)
|
|
for _, addr = range rpc_addrs {
|
|
rpcaddr, err = net.ResolveTCPAddr("tcp", addr) // Make this interruptable???
|
|
if err != nil {
|
|
goto oops
|
|
}
|
|
|
|
l, err = net.ListenTCP("tcp", rpcaddr)
|
|
if err != nil {
|
|
goto oops
|
|
}
|
|
|
|
s.rpc = append(s.rpc, l)
|
|
}
|
|
|
|
s.ctltlscfg = ctltlscfg
|
|
s.rpctlscfg = rpctlscfg
|
|
s.pxytlscfg = pxytlscfg
|
|
s.wpxtlscfg = wpxtlscfg
|
|
s.ext_svcs = make([]Service, 0, 1)
|
|
s.pts_limit = peer_max
|
|
s.cts_limit = rpc_max
|
|
s.cts_next_id = 1
|
|
s.cts_map = make(ServerConnMap)
|
|
s.cts_map_by_addr = make(ServerConnMapByAddr)
|
|
s.svc_port_map = make(ServerSvcPortMap)
|
|
s.stop_chan = make(chan bool, 8)
|
|
s.stop_req.Store(false)
|
|
|
|
/*
|
|
creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem"))
|
|
if err != nil {
|
|
log.Fatalf("failed to create credentials: %v", err)
|
|
}
|
|
gs = grpc.NewServer(grpc.Creds(creds))
|
|
*/
|
|
|
|
opts = append(opts, grpc.StatsHandler(&ConnCatcher{server: &s}))
|
|
if s.rpctlscfg != nil { opts = append(opts, grpc.Creds(credentials.NewTLS(s.rpctlscfg))) }
|
|
//opts = append(opts, grpc.UnaryInterceptor(unaryInterceptor))
|
|
//opts = append(opts, grpc.StreamInterceptor(streamInterceptor))
|
|
s.rpc_svr = grpc.NewServer(opts...)
|
|
RegisterHoduServer(s.rpc_svr, &s)
|
|
|
|
// ---------------------------------------------------------
|
|
|
|
hs_log = log.New(&server_http_log_writer{svr: &s}, "", 0)
|
|
|
|
// ---------------------------------------------------------
|
|
|
|
s.ctl_prefix = ctl_prefix
|
|
s.ctl_mux = http.NewServeMux()
|
|
|
|
s.ctl_mux.Handle(s.ctl_prefix + "/_ctl/server-conns",
|
|
s.wrap_http_handler(&server_ctl_server_conns{server_ctl{s: &s, id: "ctl"}}))
|
|
s.ctl_mux.Handle(s.ctl_prefix + "/_ctl/server-conns/{conn_id}",
|
|
s.wrap_http_handler(&server_ctl_server_conns_id{server_ctl{s: &s, id: "ctl"}}))
|
|
s.ctl_mux.Handle(s.ctl_prefix + "/_ctl/server-conns/{conn_id}/routes",
|
|
s.wrap_http_handler(&server_ctl_server_conns_id_routes{server_ctl{s: &s, id: "ctl"}}))
|
|
s.ctl_mux.Handle(s.ctl_prefix + "/_ctl/server-conns/{conn_id}/routes/{route_id}",
|
|
s.wrap_http_handler(&server_ctl_server_conns_id_routes_id{server_ctl{s: &s, id: "ctl"}}))
|
|
s.ctl_mux.Handle(s.ctl_prefix + "/_ctl/stats",
|
|
s.wrap_http_handler(&server_ctl_stats{server_ctl{s: &s, id: "ctl"}}))
|
|
|
|
s.ctl_addr = make([]string, len(ctl_addrs))
|
|
s.ctl = make([]*http.Server, len(ctl_addrs))
|
|
copy(s.ctl_addr, ctl_addrs)
|
|
|
|
for i = 0; i < len(ctl_addrs); i++ {
|
|
s.ctl[i] = &http.Server{
|
|
Addr: ctl_addrs[i],
|
|
Handler: s.ctl_mux,
|
|
TLSConfig: s.ctltlscfg,
|
|
ErrorLog: hs_log,
|
|
// TODO: more settings
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------
|
|
|
|
s.pxy_ws = &server_proxy_ssh_ws{s: &s, id: "pxy-ws"}
|
|
s.pxy_mux = http.NewServeMux() // TODO: make /_init,_ssh,_ssh_ws,_http configurable...
|
|
s.pxy_mux.Handle("/_ssh-ws/{conn_id}/{route_id}",
|
|
websocket.Handler(func(ws *websocket.Conn) { s.pxy_ws.ServeWebsocket(ws) }))
|
|
s.pxy_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}",
|
|
s.wrap_http_handler(&server_ctl_server_conns_id_routes_id{server_ctl{s: &s, id: "pxy"}}))
|
|
s.pxy_mux.Handle("/_ssh/{conn_id}/",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "pxy"}, file: "_redirect"}))
|
|
s.pxy_mux.Handle("/_ssh/{conn_id}/{route_id}/",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "pxy"}, file: "xterm.html"}))
|
|
s.pxy_mux.Handle("/_ssh/xterm.js",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "pxy"}, file: "xterm.js"}))
|
|
s.pxy_mux.Handle("/_ssh/xterm-addon-fit.js",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "pxy"}, file: "xterm-addon-fit.js"}))
|
|
s.pxy_mux.Handle("/_ssh/xterm.css",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "pxy"}, file: "xterm.css"}))
|
|
s.pxy_mux.Handle("/_ssh/",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "pxy"}, file: "_forbidden"}))
|
|
|
|
s.pxy_mux.Handle("/_http/{conn_id}/{route_id}/{trailer...}",
|
|
s.wrap_http_handler(&server_proxy_http_main{server_proxy: server_proxy{s: &s, id: "pxy"}, prefix: "/_http"}))
|
|
|
|
s.pxy_addr = make([]string, len(pxy_addrs))
|
|
s.pxy = make([]*http.Server, len(pxy_addrs))
|
|
copy(s.pxy_addr, pxy_addrs)
|
|
|
|
for i = 0; i < len(pxy_addrs); i++ {
|
|
s.pxy[i] = &http.Server{
|
|
Addr: pxy_addrs[i],
|
|
Handler: s.pxy_mux,
|
|
TLSConfig: s.pxytlscfg,
|
|
ErrorLog: hs_log,
|
|
// TODO: more settings
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------
|
|
|
|
s.wpx_mux = http.NewServeMux()
|
|
|
|
s.wpx_ws = &server_proxy_ssh_ws{s: &s, id: "wpx-ssh"}
|
|
s.wpx_mux = http.NewServeMux() // TODO: make /_init,_ssh,_ssh_ws,_http configurable...
|
|
s.wpx_mux.Handle("/_ssh-ws/{conn_id}/{route_id}",
|
|
websocket.Handler(func(ws *websocket.Conn) { s.wpx_ws.ServeWebsocket(ws) }))
|
|
|
|
s.wpx_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}",
|
|
s.wrap_http_handler(&server_ctl_server_conns_id_routes_id{server_ctl{s: &s, id: "wpx"}}))
|
|
s.wpx_mux.Handle("/_ssh/xterm.js",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "wpx"}, file: "xterm.js"}))
|
|
s.wpx_mux.Handle("/_ssh/xterm-addon-fit.js",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "wpx"}, file: "xterm-addon-fit.js"}))
|
|
s.wpx_mux.Handle("/_ssh/xterm.css",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "wpx"}, file: "xterm.css"}))
|
|
s.wpx_mux.Handle("/_ssh/{port_id}",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "wpx"}, file: "xterm.html"}))
|
|
s.wpx_mux.Handle("/_ssh/",
|
|
s.wrap_http_handler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: "wpx"}, file: "_forbidden"}))
|
|
|
|
s.wpx_mux.Handle("/{port_id}/{trailer...}",
|
|
s.wrap_http_handler(&server_proxy_http_main{server_proxy: server_proxy{s: &s, id: "wpx"}, prefix: PORT_ID_MARKER}))
|
|
s.wpx_mux.Handle("/",
|
|
s.wrap_http_handler(&server_proxy_http_wpx{server_proxy: server_proxy{s: &s, id: "wpx"}}))
|
|
|
|
s.wpx_addr = make([]string, len(wpx_addrs))
|
|
s.wpx = make([]*http.Server, len(wpx_addrs))
|
|
copy(s.wpx_addr, wpx_addrs)
|
|
|
|
for i = 0; i < len(wpx_addrs); i++ {
|
|
s.wpx[i] = &http.Server{
|
|
Addr: wpx_addrs[i],
|
|
Handler: s.wpx_mux,
|
|
TLSConfig: s.wpxtlscfg,
|
|
ErrorLog: hs_log,
|
|
// TODO: more settings
|
|
}
|
|
}
|
|
// ---------------------------------------------------------
|
|
|
|
s.stats.conns.Store(0)
|
|
s.stats.routes.Store(0)
|
|
s.stats.peers.Store(0)
|
|
s.stats.ssh_proxy_sessions.Store(0)
|
|
|
|
return &s, nil
|
|
|
|
oops:
|
|
if gl != nil { gl.Close() }
|
|
for _, l = range s.rpc { l.Close() }
|
|
s.rpc = make([]*net.TCPListener, 0)
|
|
return nil, err
|
|
}
|
|
|
|
func (s *Server) SetWpxResponseTransformer(tf ServerWpxResponseTransformer) {
|
|
s.wpx_resp_tf = tf
|
|
}
|
|
|
|
func (s *Server) GetWpxResponseTransformer() ServerWpxResponseTransformer {
|
|
return s.wpx_resp_tf
|
|
}
|
|
|
|
func (s *Server) SetWpxForeignPortProxyMaker(pm ServerWpxForeignPortProxyMaker) {
|
|
s.wpx_foreign_port_proxy_maker = pm
|
|
}
|
|
|
|
func (s *Server) GetWpxForeignPortProxyMaker() ServerWpxForeignPortProxyMaker {
|
|
return s.wpx_foreign_port_proxy_maker
|
|
}
|
|
|
|
|
|
func (s *Server) SetXtermHtml(html string) {
|
|
s.xterm_html = html
|
|
}
|
|
|
|
func (s *Server) GetXtermHtml() string {
|
|
return s.xterm_html
|
|
}
|
|
|
|
func (s *Server) run_grpc_server(idx int, wg *sync.WaitGroup) error {
|
|
var l *net.TCPListener
|
|
var err error
|
|
|
|
defer wg.Done()
|
|
|
|
l = s.rpc[idx]
|
|
// it seems to be safe to call a single grpc server on differnt listening sockets multiple times
|
|
s.log.Write("", LOG_INFO, "Starting RPC server on %s", l.Addr().String())
|
|
err = s.rpc_svr.Serve(l)
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
s.log.Write("", LOG_INFO, "RPC server on %s closed", l.Addr().String())
|
|
} else {
|
|
s.log.Write("", LOG_ERROR, "Error from RPC server on %s - %s", l.Addr().String(), err.Error())
|
|
}
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) RunTask(wg *sync.WaitGroup) {
|
|
var idx int
|
|
|
|
defer wg.Done()
|
|
|
|
for idx, _ = range s.rpc {
|
|
s.rpc_wg.Add(1)
|
|
go s.run_grpc_server(idx, &s.rpc_wg)
|
|
}
|
|
|
|
// most the work is done by in separate goroutines (s.run_grp_server)
|
|
// this loop serves as a placeholder to prevent the logic flow from
|
|
// descening down to s.ReqStop()
|
|
task_loop:
|
|
for {
|
|
select {
|
|
case <-s.stop_chan:
|
|
break task_loop
|
|
}
|
|
}
|
|
|
|
s.ReqStop()
|
|
|
|
s.rpc_wg.Wait()
|
|
s.log.Write("", LOG_DEBUG, "All RPC listeners ended")
|
|
|
|
s.cts_wg.Wait()
|
|
s.log.Write("", LOG_DEBUG, "All CTS handlers ended")
|
|
|
|
// stop the main grpc server after all the other tasks are finished.
|
|
s.rpc_svr.Stop()
|
|
}
|
|
|
|
func (s *Server) RunCtlTask(wg *sync.WaitGroup) {
|
|
var err error
|
|
var ctl *http.Server
|
|
var idx int
|
|
var l_wg sync.WaitGroup
|
|
|
|
defer wg.Done()
|
|
|
|
for idx, ctl = range s.ctl {
|
|
l_wg.Add(1)
|
|
go func(i int, cs *http.Server) {
|
|
var l net.Listener
|
|
|
|
s.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, s.ctl_addr[i])
|
|
|
|
if s.stop_req.Load() == false {
|
|
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
|
|
// err = cs.ListenAndServe()
|
|
// err = cs.ListenAndServeTLS("", "")
|
|
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
|
if err == nil {
|
|
if s.stop_req.Load() == false {
|
|
if s.ctltlscfg == nil {
|
|
err = cs.Serve(l)
|
|
} else {
|
|
err = cs.ServeTLS(l, "", "") // s.ctltlscfg must provide a certificate and a key
|
|
}
|
|
} else {
|
|
err = fmt.Errorf("stop requested")
|
|
}
|
|
l.Close()
|
|
}
|
|
} else {
|
|
err = fmt.Errorf("stop requested")
|
|
}
|
|
if errors.Is(err, http.ErrServerClosed) {
|
|
s.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
|
|
} else {
|
|
s.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
|
|
}
|
|
l_wg.Done()
|
|
}(idx, ctl)
|
|
}
|
|
l_wg.Wait()
|
|
}
|
|
|
|
func (s *Server) RunPxyTask(wg *sync.WaitGroup) {
|
|
var err error
|
|
var pxy *http.Server
|
|
var idx int
|
|
var l_wg sync.WaitGroup
|
|
|
|
defer wg.Done()
|
|
|
|
for idx, pxy = range s.pxy {
|
|
l_wg.Add(1)
|
|
go func(i int, cs *http.Server) {
|
|
var l net.Listener
|
|
|
|
s.log.Write("", LOG_INFO, "Proxy channel[%d] started on %s", i, s.pxy_addr[i])
|
|
|
|
if s.stop_req.Load() == false {
|
|
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
|
if err == nil {
|
|
if s.stop_req.Load() == false {
|
|
if s.pxytlscfg == nil { // TODO: change this
|
|
err = cs.Serve(l)
|
|
} else {
|
|
err = cs.ServeTLS(l, "", "") // s.pxytlscfg must provide a certificate and a key
|
|
}
|
|
} else {
|
|
err = fmt.Errorf("stop requested")
|
|
}
|
|
l.Close()
|
|
}
|
|
} else {
|
|
err = fmt.Errorf("stop requested")
|
|
}
|
|
if errors.Is(err, http.ErrServerClosed) {
|
|
s.log.Write("", LOG_INFO, "Proxy channel[%d] ended", i)
|
|
} else {
|
|
s.log.Write("", LOG_ERROR, "Proxy channel[%d] error - %s", i, err.Error())
|
|
}
|
|
l_wg.Done()
|
|
}(idx, pxy)
|
|
}
|
|
l_wg.Wait()
|
|
}
|
|
|
|
func (s *Server) RunWpxTask(wg *sync.WaitGroup) {
|
|
var err error
|
|
var wpx *http.Server
|
|
var idx int
|
|
var l_wg sync.WaitGroup
|
|
|
|
defer wg.Done()
|
|
|
|
for idx, wpx = range s.wpx {
|
|
l_wg.Add(1)
|
|
go func(i int, cs *http.Server) {
|
|
var l net.Listener
|
|
|
|
s.log.Write("", LOG_INFO, "Wpx channel[%d] started on %s", i, s.wpx_addr[i])
|
|
|
|
if s.stop_req.Load() == false {
|
|
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
|
if err == nil {
|
|
if s.stop_req.Load() == false {
|
|
if s.wpxtlscfg == nil { // TODO: change this
|
|
err = cs.Serve(l)
|
|
} else {
|
|
err = cs.ServeTLS(l, "", "") // s.wpxtlscfg must provide a certificate and a key
|
|
}
|
|
} else {
|
|
err = fmt.Errorf("stop requested")
|
|
}
|
|
l.Close()
|
|
}
|
|
} else {
|
|
err = fmt.Errorf("stop requested")
|
|
}
|
|
if errors.Is(err, http.ErrServerClosed) {
|
|
s.log.Write("", LOG_INFO, "Wpx channel[%d] ended", i)
|
|
} else {
|
|
s.log.Write("", LOG_ERROR, "Wpx channel[%d] error - %s", i, err.Error())
|
|
}
|
|
l_wg.Done()
|
|
}(idx, wpx)
|
|
}
|
|
l_wg.Wait()
|
|
}
|
|
|
|
func (s *Server) ReqStop() {
|
|
if s.stop_req.CompareAndSwap(false, true) {
|
|
var l *net.TCPListener
|
|
var cts *ServerConn
|
|
var hs *http.Server
|
|
|
|
// call cancellation function before anything else
|
|
// to break sub-tasks relying on this server context.
|
|
// for example, http.Client in server_proxy_http_main
|
|
s.ctx_cancel()
|
|
|
|
for _, hs = range s.ctl {
|
|
hs.Shutdown(s.ctx) // to break s.ctl.Serve()
|
|
}
|
|
|
|
for _, hs = range s.pxy {
|
|
hs.Shutdown(s.ctx) // to break s.pxy.Serve()
|
|
}
|
|
|
|
for _, hs = range s.wpx {
|
|
hs.Shutdown(s.ctx) // to break s.wpx.Serve()
|
|
}
|
|
|
|
//s.rpc_svr.GracefulStop()
|
|
//s.rpc_svr.Stop()
|
|
for _, l = range s.rpc {
|
|
l.Close()
|
|
}
|
|
|
|
// request to stop connections from/to peer held in the cts structure
|
|
s.cts_mtx.Lock()
|
|
for _, cts = range s.cts_map { cts.ReqStop() }
|
|
s.cts_mtx.Unlock()
|
|
|
|
s.stop_chan <- true
|
|
}
|
|
}
|
|
|
|
func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, pss Hodu_PacketStreamServer) (*ServerConn, error) {
|
|
var cts ServerConn
|
|
var start_id ConnId
|
|
var assigned_id ConnId
|
|
var ok bool
|
|
|
|
cts.svr = s
|
|
cts.route_map = make(ServerRouteMap)
|
|
cts.RemoteAddr = *remote_addr
|
|
cts.LocalAddr = *local_addr
|
|
cts.pss = &GuardedPacketStreamServer{Hodu_PacketStreamServer: pss}
|
|
|
|
cts.stop_req.Store(false)
|
|
cts.stop_chan = make(chan bool, 8)
|
|
|
|
s.cts_mtx.Lock()
|
|
|
|
if s.cts_limit > 0 && len(s.cts_map) >= s.cts_limit {
|
|
s.cts_mtx.Unlock()
|
|
return nil, fmt.Errorf("too many connections - %d", s.cts_limit)
|
|
}
|
|
|
|
//start_id = rand.Uint64()
|
|
//start_id = ConnId(monotonic_time() / 1000)
|
|
start_id = s.cts_next_id
|
|
for {
|
|
_, ok = s.cts_map[s.cts_next_id]
|
|
if !ok {
|
|
assigned_id = s.cts_next_id
|
|
s.cts_next_id++
|
|
if s.cts_next_id == 0 { s.cts_next_id++ }
|
|
break
|
|
}
|
|
s.cts_next_id++
|
|
if s.cts_next_id == 0 { s.cts_next_id++ }
|
|
if s.cts_next_id == start_id {
|
|
s.cts_mtx.Unlock()
|
|
return nil, fmt.Errorf("unable to assign id")
|
|
}
|
|
}
|
|
cts.Id = assigned_id
|
|
cts.sid = fmt.Sprintf("%d", cts.Id) // id in string used for logging
|
|
|
|
_, ok = s.cts_map_by_addr[cts.RemoteAddr]
|
|
if ok {
|
|
s.cts_mtx.Unlock()
|
|
return nil, fmt.Errorf("existing client - %s", cts.RemoteAddr.String())
|
|
}
|
|
s.cts_map_by_addr[cts.RemoteAddr] = &cts
|
|
s.cts_map[cts.Id] = &cts
|
|
s.stats.conns.Store(int64(len(s.cts_map)))
|
|
s.cts_mtx.Unlock()
|
|
|
|
s.log.Write(cts.sid, LOG_DEBUG, "Added client connection from %s", cts.RemoteAddr.String())
|
|
return &cts, nil
|
|
}
|
|
|
|
func (s *Server) ReqStopAllServerConns() {
|
|
var cts *ServerConn
|
|
s.cts_mtx.Lock()
|
|
for _, cts = range s.cts_map { cts.ReqStop() }
|
|
s.cts_mtx.Unlock()
|
|
}
|
|
|
|
func (s *Server) RemoveServerConn(cts *ServerConn) error {
|
|
var conn *ServerConn
|
|
var ok bool
|
|
|
|
s.cts_mtx.Lock()
|
|
|
|
conn, ok = s.cts_map[cts.Id]
|
|
if !ok {
|
|
s.cts_mtx.Unlock()
|
|
return fmt.Errorf("non-existent connection id - %d", cts.Id)
|
|
}
|
|
if conn != cts {
|
|
s.cts_mtx.Unlock()
|
|
return fmt.Errorf("non-existent connection id - %d", cts.Id)
|
|
}
|
|
|
|
delete(s.cts_map, cts.Id)
|
|
delete(s.cts_map_by_addr, cts.RemoteAddr)
|
|
s.stats.conns.Store(int64(len(s.cts_map)))
|
|
s.cts_mtx.Unlock()
|
|
|
|
cts.ReqStop()
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) RemoveServerConnByAddr(addr net.Addr) (*ServerConn, error) {
|
|
var cts *ServerConn
|
|
var ok bool
|
|
|
|
s.cts_mtx.Lock()
|
|
|
|
cts, ok = s.cts_map_by_addr[addr]
|
|
if !ok {
|
|
s.cts_mtx.Unlock()
|
|
return nil, fmt.Errorf("non-existent connection address - %s", addr.String())
|
|
}
|
|
delete(s.cts_map, cts.Id)
|
|
delete(s.cts_map_by_addr, cts.RemoteAddr)
|
|
s.stats.conns.Store(int64(len(s.cts_map)))
|
|
s.cts_mtx.Unlock()
|
|
|
|
cts.ReqStop()
|
|
return cts, nil
|
|
}
|
|
|
|
func (s *Server) FindServerConnById(id ConnId) *ServerConn {
|
|
var cts *ServerConn
|
|
var ok bool
|
|
|
|
s.cts_mtx.Lock()
|
|
defer s.cts_mtx.Unlock()
|
|
|
|
cts, ok = s.cts_map[id]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
return cts
|
|
}
|
|
|
|
func (s *Server) FindServerConnByAddr(addr net.Addr) *ServerConn {
|
|
var cts *ServerConn
|
|
var ok bool
|
|
|
|
s.cts_mtx.Lock()
|
|
defer s.cts_mtx.Unlock()
|
|
|
|
cts, ok = s.cts_map_by_addr[addr]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
return cts
|
|
}
|
|
|
|
func (s *Server) FindServerRouteById(id ConnId, route_id RouteId) *ServerRoute {
|
|
var cts *ServerConn
|
|
var ok bool
|
|
|
|
s.cts_mtx.Lock()
|
|
defer s.cts_mtx.Unlock()
|
|
|
|
cts, ok = s.cts_map[id]
|
|
if !ok { return nil }
|
|
|
|
return cts.FindServerRouteById(route_id)
|
|
}
|
|
|
|
func (s *Server) FindServerRouteByPortId(port_id PortId) *ServerRoute {
|
|
var cri ConnRouteId
|
|
var ok bool
|
|
|
|
s.svc_port_mtx.Lock()
|
|
defer s.svc_port_mtx.Unlock()
|
|
|
|
cri, ok = s.svc_port_map[port_id]
|
|
if !ok { return nil }
|
|
return s.FindServerRouteById(cri.conn_id, cri.route_id)
|
|
}
|
|
|
|
func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*ServerRoute, error) {
|
|
var r *ServerRoute
|
|
var err error
|
|
|
|
if route_id == PORT_ID_MARKER {
|
|
var port_nid uint64
|
|
|
|
port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
|
|
if err != nil { return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) }
|
|
|
|
r = s.FindServerRouteByPortId(PortId(port_nid))
|
|
if r == nil { return nil, fmt.Errorf("port(%d) not found", port_nid) }
|
|
} else {
|
|
var conn_nid uint64
|
|
var route_nid uint64
|
|
|
|
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
|
if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()) }
|
|
|
|
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
|
|
if err != nil { return nil, fmt.Errorf("invalid route id %s - %s", route_id, err.Error()) }
|
|
|
|
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
|
|
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) }
|
|
}
|
|
|
|
return r, nil
|
|
}
|
|
|
|
func (s *Server) FindServerConnByIdStr(conn_id string) (*ServerConn, error) {
|
|
var conn_nid uint64
|
|
var cts *ServerConn
|
|
var err error
|
|
|
|
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
|
if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()); }
|
|
|
|
cts = s.FindServerConnById(ConnId(conn_nid))
|
|
if cts == nil { return nil, fmt.Errorf("non-existent connection id %d", conn_nid) }
|
|
|
|
return cts, nil
|
|
}
|
|
|
|
func (s *Server) StartService(cfg interface{}) {
|
|
s.wg.Add(1)
|
|
go s.RunTask(&s.wg)
|
|
}
|
|
|
|
func (s *Server) StartExtService(svc Service, data interface{}) {
|
|
s.ext_mtx.Lock()
|
|
s.ext_svcs = append(s.ext_svcs, svc)
|
|
s.ext_mtx.Unlock()
|
|
s.wg.Add(1)
|
|
go svc.RunTask(&s.wg)
|
|
}
|
|
|
|
func (s *Server) StartCtlService() {
|
|
s.wg.Add(1)
|
|
go s.RunCtlTask(&s.wg)
|
|
}
|
|
|
|
func (s *Server) StartPxyService() {
|
|
s.wg.Add(1)
|
|
go s.RunPxyTask(&s.wg)
|
|
}
|
|
|
|
func (s *Server) StartWpxService() {
|
|
s.wg.Add(1)
|
|
go s.RunWpxTask(&s.wg)
|
|
}
|
|
|
|
func (s *Server) StopServices() {
|
|
var ext_svc Service
|
|
s.ReqStop()
|
|
for _, ext_svc = range s.ext_svcs {
|
|
ext_svc.StopServices()
|
|
}
|
|
}
|
|
|
|
func (s *Server) FixServices() {
|
|
s.log.Rotate()
|
|
}
|
|
|
|
func (s *Server) WaitForTermination() {
|
|
s.wg.Wait()
|
|
}
|
|
|
|
func (s *Server) WriteLog(id string, level LogLevel, fmtstr string, args ...interface{}) {
|
|
s.log.Write(id, level, fmtstr, args...)
|
|
}
|