added code dealing with client connection token
This commit is contained in:
parent
7a6b820b92
commit
d9aaa0a0ab
@ -1013,8 +1013,10 @@ start_over:
|
|||||||
if cts.C.token != "" {
|
if cts.C.token != "" {
|
||||||
err = cts.psc.Send(MakeConnDescPacket(cts.C.token))
|
err = cts.psc.Send(MakeConnDescPacket(cts.C.token))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send conn-desc to server[%d] %s - %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index], err.Error())
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send conn-desc(%s) to server[%d] %s - %s", cts.C.token, cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index], err.Error())
|
||||||
goto reconnect_to_server
|
goto reconnect_to_server
|
||||||
|
} else {
|
||||||
|
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Sending conn-desc(%s) to server[%d] %s", cts.C.token, cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1191,7 +1193,7 @@ start_over:
|
|||||||
cts.C.conn_notice.Handle(cts, x.Notice.Text)
|
cts.C.conn_notice.Handle(cts, x.Notice.Text)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_data event from %s", cts.remote_addr)
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_notice packet from %s", cts.remote_addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
92
server.go
92
server.go
@ -34,6 +34,7 @@ const HS_ID_PXY string = "pxy"
|
|||||||
const HS_ID_PXY_WS string = "pxy-ws"
|
const HS_ID_PXY_WS string = "pxy-ws"
|
||||||
|
|
||||||
type ServerConnMapByAddr = map[net.Addr]*ServerConn
|
type ServerConnMapByAddr = map[net.Addr]*ServerConn
|
||||||
|
type ServerConnMapByToken = map[string]*ServerConn
|
||||||
type ServerConnMap = map[ConnId]*ServerConn
|
type ServerConnMap = map[ConnId]*ServerConn
|
||||||
type ServerRouteMap = map[RouteId]*ServerRoute
|
type ServerRouteMap = map[RouteId]*ServerRoute
|
||||||
type ServerPeerConnMap = map[PeerId]*ServerPeerConn
|
type ServerPeerConnMap = map[PeerId]*ServerPeerConn
|
||||||
@ -101,6 +102,7 @@ type Server struct {
|
|||||||
cts_mtx sync.Mutex
|
cts_mtx sync.Mutex
|
||||||
cts_map ServerConnMap
|
cts_map ServerConnMap
|
||||||
cts_map_by_addr ServerConnMapByAddr
|
cts_map_by_addr ServerConnMapByAddr
|
||||||
|
cts_map_by_token ServerConnMapByToken
|
||||||
cts_wg sync.WaitGroup
|
cts_wg sync.WaitGroup
|
||||||
|
|
||||||
log Logger
|
log Logger
|
||||||
@ -750,14 +752,33 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
|
|||||||
var ok bool
|
var ok bool
|
||||||
x, ok = pkt.U.(*Packet_Conn)
|
x, ok = pkt.U.(*Packet_Conn)
|
||||||
if ok {
|
if ok {
|
||||||
cts.Token = x.Conn.Token
|
if x.Conn.Token == "" {
|
||||||
// TODO: lock
|
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s - blank token", cts.RemoteAddr)
|
||||||
// manipulate the cts_map_by_token
|
cts.ReqStop() // TODO: is this desirable to disconnect?
|
||||||
// if cts.Token is empty, not placed in the table...
|
} else if x.Conn.Token != cts.Token {
|
||||||
// removal by old token value before adding it.
|
_, err = strconv.ParseUint(x.Conn.Token, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
||||||
// unlock
|
if err == nil {
|
||||||
|
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s - numeric token '%s'", cts.RemoteAddr, x.Conn.Token)
|
||||||
|
cts.ReqStop() // TODO: is this desirable to disconnect?
|
||||||
} else {
|
} else {
|
||||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc event from %s", cts.RemoteAddr)
|
cts.S.cts_mtx.Lock()
|
||||||
|
_, ok = cts.S.cts_map_by_token[x.Conn.Token]
|
||||||
|
if ok {
|
||||||
|
// error
|
||||||
|
cts.S.cts_mtx.Unlock()
|
||||||
|
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s - duplicate token '%s'", cts.RemoteAddr, x.Conn.Token)
|
||||||
|
cts.ReqStop() // TODO: is this desirable to disconnect?
|
||||||
|
} else {
|
||||||
|
if cts.Token != "" { delete(cts.S.cts_map_by_token, cts.Token) }
|
||||||
|
cts.Token = x.Conn.Token
|
||||||
|
cts.S.cts_map_by_token[x.Conn.Token] = cts
|
||||||
|
cts.S.cts_mtx.Unlock()
|
||||||
|
cts.S.log.Write(cts.Sid, LOG_INFO, "client(%d) %s - token set to '%s'", cts.Id, cts.RemoteAddr, x.Conn.Token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s", cts.RemoteAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
case PACKET_KIND_CONN_NOTICE:
|
case PACKET_KIND_CONN_NOTICE:
|
||||||
@ -770,7 +791,7 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
|
|||||||
cts.S.conn_notice.Handle(cts, x.Notice.Text)
|
cts.S.conn_notice.Handle(cts, x.Notice.Text)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_data event from %s", cts.RemoteAddr)
|
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_notice packet from %s", cts.RemoteAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1107,6 +1128,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
|
|||||||
s.cts_next_id = 1
|
s.cts_next_id = 1
|
||||||
s.cts_map = make(ServerConnMap)
|
s.cts_map = make(ServerConnMap)
|
||||||
s.cts_map_by_addr = make(ServerConnMapByAddr)
|
s.cts_map_by_addr = make(ServerConnMapByAddr)
|
||||||
|
s.cts_map_by_token = make(ServerConnMapByToken)
|
||||||
s.svc_port_map = make(ServerSvcPortMap)
|
s.svc_port_map = make(ServerSvcPortMap)
|
||||||
s.stop_chan = make(chan bool, 8)
|
s.stop_chan = make(chan bool, 8)
|
||||||
s.stop_req.Store(false)
|
s.stop_req.Store(false)
|
||||||
@ -1562,7 +1584,17 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p
|
|||||||
_, ok = s.cts_map_by_addr[cts.RemoteAddr]
|
_, ok = s.cts_map_by_addr[cts.RemoteAddr]
|
||||||
if ok {
|
if ok {
|
||||||
s.cts_mtx.Unlock()
|
s.cts_mtx.Unlock()
|
||||||
return nil, fmt.Errorf("existing client - %s", cts.RemoteAddr.String())
|
return nil, fmt.Errorf("existing client address - %s", cts.RemoteAddr.String())
|
||||||
|
}
|
||||||
|
if cts.Token != "" {
|
||||||
|
// this check is not needed as Token is never set at this phase
|
||||||
|
// however leave statements here for completeness
|
||||||
|
_, ok = s.cts_map_by_token[cts.Token]
|
||||||
|
if ok {
|
||||||
|
s.cts_mtx.Unlock()
|
||||||
|
return nil, fmt.Errorf("existing client token - %s", cts.Token)
|
||||||
|
}
|
||||||
|
s.cts_map_by_token[cts.Token] = &cts
|
||||||
}
|
}
|
||||||
s.cts_map_by_addr[cts.RemoteAddr] = &cts
|
s.cts_map_by_addr[cts.RemoteAddr] = &cts
|
||||||
s.cts_map[cts.Id] = &cts
|
s.cts_map[cts.Id] = &cts
|
||||||
@ -1598,6 +1630,7 @@ func (s *Server) RemoveServerConn(cts *ServerConn) error {
|
|||||||
|
|
||||||
delete(s.cts_map, cts.Id)
|
delete(s.cts_map, cts.Id)
|
||||||
delete(s.cts_map_by_addr, cts.RemoteAddr)
|
delete(s.cts_map_by_addr, cts.RemoteAddr)
|
||||||
|
if cts.Token != "" { delete(s.cts_map_by_token, cts.Token) }
|
||||||
s.stats.conns.Store(int64(len(s.cts_map)))
|
s.stats.conns.Store(int64(len(s.cts_map)))
|
||||||
s.cts_mtx.Unlock()
|
s.cts_mtx.Unlock()
|
||||||
|
|
||||||
@ -1618,6 +1651,28 @@ func (s *Server) RemoveServerConnByAddr(addr net.Addr) (*ServerConn, error) {
|
|||||||
}
|
}
|
||||||
delete(s.cts_map, cts.Id)
|
delete(s.cts_map, cts.Id)
|
||||||
delete(s.cts_map_by_addr, cts.RemoteAddr)
|
delete(s.cts_map_by_addr, cts.RemoteAddr)
|
||||||
|
if cts.Token != "" { delete(s.cts_map_by_token, cts.Token) }
|
||||||
|
s.stats.conns.Store(int64(len(s.cts_map)))
|
||||||
|
s.cts_mtx.Unlock()
|
||||||
|
|
||||||
|
cts.ReqStop()
|
||||||
|
return cts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) RemoveServerConnByToken(token string) (*ServerConn, error) {
|
||||||
|
var cts *ServerConn
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
s.cts_mtx.Lock()
|
||||||
|
|
||||||
|
cts, ok = s.cts_map_by_token[token]
|
||||||
|
if !ok {
|
||||||
|
s.cts_mtx.Unlock()
|
||||||
|
return nil, fmt.Errorf("non-existent connection token - %s", token)
|
||||||
|
}
|
||||||
|
delete(s.cts_map, cts.Id)
|
||||||
|
delete(s.cts_map_by_addr, cts.RemoteAddr)
|
||||||
|
delete(s.cts_map_by_token, cts.Token) // no Empty check becuase an empty token is never found in the map
|
||||||
s.stats.conns.Store(int64(len(s.cts_map)))
|
s.stats.conns.Store(int64(len(s.cts_map)))
|
||||||
s.cts_mtx.Unlock()
|
s.cts_mtx.Unlock()
|
||||||
|
|
||||||
@ -1633,9 +1688,7 @@ func (s *Server) FindServerConnById(id ConnId) *ServerConn {
|
|||||||
defer s.cts_mtx.Unlock()
|
defer s.cts_mtx.Unlock()
|
||||||
|
|
||||||
cts, ok = s.cts_map[id]
|
cts, ok = s.cts_map[id]
|
||||||
if !ok {
|
if !ok { return nil }
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return cts
|
return cts
|
||||||
}
|
}
|
||||||
@ -1648,10 +1701,21 @@ func (s *Server) FindServerConnByAddr(addr net.Addr) *ServerConn {
|
|||||||
defer s.cts_mtx.Unlock()
|
defer s.cts_mtx.Unlock()
|
||||||
|
|
||||||
cts, ok = s.cts_map_by_addr[addr]
|
cts, ok = s.cts_map_by_addr[addr]
|
||||||
if !ok {
|
if !ok { return nil }
|
||||||
return nil
|
|
||||||
|
return cts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) FindServerConnByToken(token string) *ServerConn {
|
||||||
|
var cts *ServerConn
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
s.cts_mtx.Lock()
|
||||||
|
defer s.cts_mtx.Unlock()
|
||||||
|
|
||||||
|
cts, ok = s.cts_map_by_token[token]
|
||||||
|
if !ok { return nil }
|
||||||
|
|
||||||
return cts
|
return cts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user