package auth import "context" import "fmt" import "net" import "strings" import "time" import "crypto/tls" import "codit/internal/config" import "github.com/go-ldap/ldap/v3" type LDAPUser struct { Username string DisplayName string Email string } const LDAPOperationTimeout time.Duration = 8 * time.Second func LDAPAuthenticate(cfg config.Config, username, password string) (LDAPUser, error) { return LDAPAuthenticateContext(context.Background(), cfg, username, password) } func LDAPAuthenticateContext(ctx context.Context, cfg config.Config, username, password string) (LDAPUser, error) { var conn *ldap.Conn var cleanup func() var err error var filter string var search *ldap.SearchRequest var res *ldap.SearchResult var entry *ldap.Entry var userDN string var user LDAPUser conn, cleanup, err = ldapConnWithContext(ctx, cfg) if err != nil { return LDAPUser{}, err } defer cleanup() if cfg.LDAPBindDN != "" { err = conn.Bind(cfg.LDAPBindDN, cfg.LDAPBindPassword) if err != nil { if ctx.Err() != nil { return LDAPUser{}, ctx.Err() } return LDAPUser{}, err } } filter = strings.ReplaceAll(cfg.LDAPUserFilter, "{username}", ldap.EscapeFilter(username)) search = ldap.NewSearchRequest( cfg.LDAPUserBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 1, 0, false, filter, []string{"dn", "cn", "mail", "uid"}, nil, ) res, err = conn.Search(search) if err != nil { if ctx.Err() != nil { return LDAPUser{}, ctx.Err() } return LDAPUser{}, err } if len(res.Entries) == 0 { return LDAPUser{}, fmt.Errorf("ldap user not found") } entry = res.Entries[0] userDN = entry.DN err = conn.Bind(userDN, password) if err != nil { if ctx.Err() != nil { return LDAPUser{}, ctx.Err() } return LDAPUser{}, err } user = LDAPUser{ Username: username, DisplayName: entry.GetAttributeValue("cn"), Email: entry.GetAttributeValue("mail"), } if user.DisplayName == "" { user.DisplayName = username } return user, nil } func LDAPTestConnection(cfg config.Config) error { return LDAPTestConnectionContext(context.Background(), cfg) } func LDAPTestConnectionContext(ctx context.Context, cfg config.Config) error { var conn *ldap.Conn var cleanup func() var err error conn, cleanup, err = ldapConnWithContext(ctx, cfg) if err != nil { return err } defer cleanup() if cfg.LDAPBindDN != "" { err = conn.Bind(cfg.LDAPBindDN, cfg.LDAPBindPassword) if err != nil { if ctx.Err() != nil { return ctx.Err() } return err } } return nil } func ldapConnWithContext(ctx context.Context, cfg config.Config) (*ldap.Conn, func(), error) { var opCtx context.Context var cancel context.CancelFunc var conn *ldap.Conn var err error var dialer *net.Dialer var tlsConfig *tls.Config var opts []ldap.DialOpt var done chan struct{} opCtx, cancel = context.WithTimeout(ctx, LDAPOperationTimeout) dialer = &net.Dialer{Timeout: LDAPOperationTimeout} opts = make([]ldap.DialOpt, 0, 2) opts = append(opts, ldap.DialWithDialer(dialer)) tlsConfig = &tls.Config{InsecureSkipVerify: cfg.LDAPTLSInsecureSkipVerify} opts = append(opts, ldap.DialWithTLSConfig(tlsConfig)) conn, err = ldap.DialURL(cfg.LDAPURL, opts...) if err != nil { cancel() return nil, nil, err } conn.SetTimeout(LDAPOperationTimeout) done = make(chan struct{}) go func() { select { case <-opCtx.Done(): _ = conn.Close() case <-done: } }() return conn, func() { close(done) cancel() _ = conn.Close() }, nil }