diff --git a/README.md b/README.md index 408e163..88ccdcf 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ "server-peer-option": "tcp4 ssh", "server-peer-service-addr": "0.0.0.0:0", "server-peer-service-net": "", - "lifetime": "0s" + "lifetime": "0" } ``` diff --git a/client-ctl.go b/client-ctl.go index ba78d54..227611f 100644 --- a/client-ctl.go +++ b/client-ctl.go @@ -41,6 +41,10 @@ type json_in_client_route struct { Lifetime string `json:"lifetime"` } +type json_in_client_route_update struct { + Lifetime string `sjon:"lifetime"` +} + type json_out_client_conn_id struct { Id ConnId `json:"id"` } @@ -55,7 +59,8 @@ type json_out_client_conn struct { } type json_out_client_route_id struct { - Id RouteId `json:"id"` + Id RouteId `json:"id"` + CtsId ConnId `json:"conn-id"` } type json_out_client_route struct { @@ -66,6 +71,7 @@ type json_out_client_route struct { ServerPeerListenAddr string `json:"server-peer-service-addr"` ServerPeerNet string `json:"server-peer-service-net"` Lifetime string `json:"lifetime"` + LifetimeStart int64 `json:"lifetime-start"` } type json_out_client_peer struct { @@ -176,7 +182,8 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R ServerPeerListenAddr: r.server_peer_listen_addr.String(), ServerPeerNet: r.server_peer_net, ServerPeerOption: r.server_peer_option.string(), - Lifetime: r.lifetime.String(), + Lifetime: fmt.Sprintf("%.09f", r.lifetime.Seconds()), + LifetimeStart: r.lifetime_start.Unix(), }) } js = append(js, json_out_client_conn{ @@ -293,7 +300,8 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt ServerPeerListenAddr: r.server_peer_listen_addr.String(), ServerPeerNet: r.server_peer_net, ServerPeerOption: r.server_peer_option.string(), - Lifetime: r.lifetime.String(), + Lifetime: fmt.Sprintf("%.09f", r.lifetime.Seconds()), + LifetimeStart: r.lifetime_start.Unix(), }) } js = &json_out_client_conn{ @@ -375,7 +383,8 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r ServerPeerListenAddr: r.server_peer_listen_addr.String(), ServerPeerNet: r.server_peer_net, ServerPeerOption: r.server_peer_option.string(), - Lifetime: r.lifetime.String(), + Lifetime: fmt.Sprintf("%.09f", r.lifetime.Seconds()), + LifetimeStart: r.lifetime_start.Unix(), }) } cts.route_mtx.Unlock() @@ -409,13 +418,11 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r goto oops } - if jcr.Lifetime != "" { - lifetime, err = time.ParseDuration(jcr.Lifetime) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - err = fmt.Errorf("wrong lifetime value %s - %s", jcr.Lifetime, err.Error()) - goto oops - } + lifetime, err = parse_duration_string(jcr.Lifetime) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + err = fmt.Errorf("wrong lifetime value %s - %s", jcr.Lifetime, err.Error()) + goto oops } rc = &ClientRouteConfig{ @@ -434,7 +441,7 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } } else { status_code = http.StatusCreated; w.WriteHeader(status_code) - if err = je.Encode(json_out_client_route_id{Id: r.id}); err != nil { goto oops } + if err = je.Encode(json_out_client_route_id{Id: r.id, CtsId: r.cts.id}); err != nil { goto oops } } case http.MethodDelete: @@ -519,6 +526,27 @@ func (ctl *client_ctl_client_conns_id_routes_id) ServeHTTP(w http.ResponseWriter }) if err != nil { goto oops } + case http.MethodPut: + var jcr json_in_client_route_update + var lifetime time.Duration + + err = json.NewDecoder(req.Body).Decode(&jcr) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + goto oops + } + + lifetime, err = parse_duration_string(jcr.Lifetime) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + err = fmt.Errorf("wrong lifetime value %s - %s", jcr.Lifetime, err.Error()) + goto oops + } + + + err = r.ResetLifetime(lifetime) + if err != nil { goto oops } + case http.MethodDelete: r.ReqStop() status_code = http.StatusNoContent; w.WriteHeader(status_code) diff --git a/client.go b/client.go index eef4541..2e1de98 100644 --- a/client.go +++ b/client.go @@ -125,7 +125,9 @@ type ClientRoute struct { ptc_wg sync.WaitGroup lifetime time.Duration + lifetime_start time.Time lifetime_timer *time.Timer + lifetime_mtx sync.Mutex stop_req atomic.Bool stop_chan chan bool @@ -183,6 +185,7 @@ func NewClientRoute(cts *ClientConn, id RouteId, client_peer_addr string, client r.server_peer_addr = server_peer_svc_addr r.server_peer_net = server_peer_svc_net // permitted network for server-side peer r.server_peer_option = server_peer_option + r.lifetime_start = time.Now() r.lifetime = lifetime r.stop_req.Store(false) r.stop_chan = make(chan bool, 8) @@ -262,6 +265,22 @@ func (r *ClientRoute) FindClientPeerConnById(conn_id PeerId) *ClientPeerConn { return c } +func (r *ClientRoute) ResetLifetime(lifetime time.Duration) error { + r.lifetime_mtx.Lock() + defer r.lifetime_mtx.Unlock() + if r.lifetime_timer == nil { + // let's not support timer reset if route was not + // first started with lifetime enabled + return fmt.Errorf("prohibited operation") + } else { + r.lifetime_timer.Stop() + r.lifetime = lifetime + r.lifetime_start = time.Now() + r.lifetime_timer.Reset(lifetime) + return nil + } +} + func (r *ClientRoute) RunTask(wg *sync.WaitGroup) { var err error @@ -282,7 +301,12 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) { r.id, r.peer_addr, r.server_peer_option, r.server_peer_net, r.cts.remote_addr) } - if r.lifetime > 0 { r.lifetime_timer = time.NewTimer(r.lifetime) } + r.lifetime_mtx.Lock() + if r.lifetime > 0 { + r.lifetime_start = time.Now() + r.lifetime_timer = time.NewTimer(r.lifetime) + } + r.lifetime_mtx.Unlock() main_loop: for { @@ -304,7 +328,12 @@ main_loop: } } - if r.lifetime_timer != nil { r.lifetime_timer.Stop() } + r.lifetime_mtx.Lock() + if r.lifetime_timer != nil { + r.lifetime_timer.Stop() + r.lifetime_timer = nil + } + r.lifetime_mtx.Unlock() done: r.ReqStop() diff --git a/hodu.go b/hodu.go index 0c57319..c9db838 100644 --- a/hodu.go +++ b/hodu.go @@ -7,6 +7,7 @@ import "os" import "runtime" import "strings" import "sync" +import "time" const HODU_RPC_VERSION uint32 = 0x010000 @@ -130,3 +131,27 @@ func svc_addr_to_dst_addr (svc_addr *net.TCPAddr) *net.TCPAddr { return &addr } + +func is_digit_or_period(r rune) bool { + return (r >= '0' && r <= '9') || r == '.' +} + +func get_last_rune_of_non_empty_string(s string) rune { + var tmp []rune + // the string must not be blank for this to work + tmp = []rune(s) + return tmp[len(tmp) - 1] +} + +func parse_duration_string(dur string) (time.Duration, error) { + // i want the input to be in seconds with resolution of 9 digits after + // the decimal point. For example, 0.05 to mean 500ms. + // however, i don't care if a unit is part of the input. + var tmp string + + if dur == "" { return 0, nil } + + tmp = dur + if is_digit_or_period(get_last_rune_of_non_empty_string(tmp)) { tmp = tmp + "s" } + return time.ParseDuration(tmp) +}