relocated functions

This commit is contained in:
hyung-hwan 2024-12-17 09:35:51 +09:00
parent e177b7dc6b
commit 3fc48eb590
3 changed files with 64 additions and 75 deletions

View File

@ -151,7 +151,7 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt
je = json.NewEncoder(w) je = json.NewEncoder(w)
conn_id = req.PathValue("conn_id") conn_id = req.PathValue("conn_id")
cts, err = get_server_conn(s, conn_id) cts, err = s.FindServerConnByIdStr(conn_id)
if err != nil { if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code) status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops }
@ -224,7 +224,7 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r
je = json.NewEncoder(w) je = json.NewEncoder(w)
conn_id = req.PathValue("conn_id") conn_id = req.PathValue("conn_id")
cts, err = get_server_conn(s, conn_id) cts, err = s.FindServerConnByIdStr(conn_id)
if err != nil { if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code) status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops }
@ -292,7 +292,7 @@ func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter
conn_id = req.PathValue("conn_id") conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id") route_id = req.PathValue("route_id")
r, err = get_server_route(s, conn_id, route_id) r, err = s.FindServerRouteByIdStr(conn_id, route_id)
if err != nil { if err != nil {
status_code = http.StatusNotFound; w.WriteHeader(status_code) status_code = http.StatusNotFound; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops }

View File

@ -14,7 +14,6 @@ import "strings"
import "sync" import "sync"
import "text/template" import "text/template"
import "time" import "time"
import "unsafe"
import "golang.org/x/crypto/ssh" import "golang.org/x/crypto/ssh"
import "golang.org/x/net/http/httpguts" import "golang.org/x/net/http/httpguts"
@ -134,73 +133,6 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, add_path
} }
} }
func parse_conn_route_id(conn_id string, route_id string) (ConnId, RouteId, error) {
var conn_nid uint64
var route_nid uint64
var err error
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil { return ConnId(0), RouteId(0), fmt.Errorf("invalid connection id - %s", err.Error()) }
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil { return ConnId(0), RouteId(0), fmt.Errorf("invalid route id - %s", err.Error()) }
return ConnId(conn_nid), RouteId(route_nid), nil
}
func get_server_conn(s *Server, conn_id string) (*ServerConn, error) {
var conn_nid uint64
var cts *ServerConn
var err error
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error());
}
cts = s.FindServerConnById(ConnId(conn_nid))
if cts == nil {
return nil, fmt.Errorf("non-existent connection id %d", conn_nid)
}
return cts, nil
}
func get_server_route(s *Server, conn_id string, route_id string) (*ServerRoute, error) {
var r *ServerRoute
var err error
if route_id == "_" {
var port_nid uint64
port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
if err != nil {
return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error())
}
r = s.FindServerRouteByPortId(PortId(port_nid))
if r == nil {
return nil, fmt.Errorf("port(%d) not found", port_nid)
}
} else {
var conn_nid uint64
var route_nid uint64
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid connection id - %s", err.Error()) }
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid route id - %s", err.Error()) }
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil {
return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid)
}
}
return r, nil
}
// ------------------------------------ // ------------------------------------
func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@ -222,7 +154,7 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re
conn_id = req.PathValue("conn_id") conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id") route_id = req.PathValue("route_id")
r, err = get_server_route(s, conn_id, route_id) r, err = s.FindServerRouteByIdStr(conn_id, route_id)
if err != nil { if err != nil {
status_code = http.StatusNotFound; w.WriteHeader(status_code) status_code = http.StatusNotFound; w.WriteHeader(status_code)
goto oops goto oops
@ -306,7 +238,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
route_id = ids[1] route_id = ids[1]
} }
r, err = get_server_route(s, conn_id, route_id) r, err = s.FindServerRouteByIdStr(conn_id, route_id)
if err != nil { if err != nil {
status_code = http.StatusNotFound; w.WriteHeader(status_code) status_code = http.StatusNotFound; w.WriteHeader(status_code)
goto oops goto oops
@ -410,8 +342,11 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
hdr.Add("Set-Cookie", fmt.Sprintf("%s=%d-%d; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) hdr.Add("Set-Cookie", fmt.Sprintf("%s=%d-%d; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id))
} }
status_code = resp.StatusCode; w.WriteHeader(status_code) status_code = resp.StatusCode; w.WriteHeader(status_code)
// TODO" if prefixed { append prefix removed...
io.Copy(w, resp.Body) io.Copy(w, resp.Body)
// TODO: handle trailers // TODO: handle trailers
} }
@ -470,7 +405,7 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R
conn_id = req.PathValue("conn_id") conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id") route_id = req.PathValue("route_id")
_, err = get_server_route(s, conn_id, route_id) _, err = s.FindServerRouteByIdStr(conn_id, route_id)
if err != nil { if err != nil {
status_code = http.StatusNotFound; w.WriteHeader(status_code) status_code = http.StatusNotFound; w.WriteHeader(status_code)
goto oops goto oops
@ -625,7 +560,7 @@ func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) {
conn_id = req.PathValue("conn_id") conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id") route_id = req.PathValue("route_id")
r, err = get_server_route(s, conn_id, route_id) r, err = s.FindServerRouteByIdStr(conn_id, route_id)
if err != nil { if err != nil {
pxy.send_ws_data(ws, "error", err.Error()) pxy.send_ws_data(ws, "error", err.Error())
goto done goto done

View File

@ -9,8 +9,10 @@ import "log"
import "net" import "net"
import "net/http" import "net/http"
import "net/netip" import "net/netip"
import "strconv"
import "sync" import "sync"
import "sync/atomic" import "sync/atomic"
import "unsafe"
import "golang.org/x/net/websocket" import "golang.org/x/net/websocket"
import "google.golang.org/grpc" import "google.golang.org/grpc"
@ -1383,6 +1385,58 @@ func (s *Server) FindServerRouteByPortId(port_id PortId) *ServerRoute {
return s.FindServerRouteById(cri.conn_id, cri.route_id) return s.FindServerRouteById(cri.conn_id, cri.route_id)
} }
func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*ServerRoute, error) {
var r *ServerRoute
var err error
if route_id == "_" {
var port_nid uint64
port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
if err != nil {
return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error())
}
r = s.FindServerRouteByPortId(PortId(port_nid))
if r == nil {
return nil, fmt.Errorf("port(%d) not found", port_nid)
}
} else {
var conn_nid uint64
var route_nid uint64
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid connection id - %s", err.Error()) }
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid route id - %s", err.Error()) }
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil {
return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid)
}
}
return r, nil
}
func (s *Server) FindServerConnByIdStr(conn_id string) (*ServerConn, error) {
var conn_nid uint64
var cts *ServerConn
var err error
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error());
}
cts = s.FindServerConnById(ConnId(conn_nid))
if cts == nil {
return nil, fmt.Errorf("non-existent connection id %d", conn_nid)
}
return cts, nil
}
func (s *Server) StartService(cfg interface{}) { func (s *Server) StartService(cfg interface{}) {
s.wg.Add(1) s.wg.Add(1)