fix: harden oidc compat email and email bind tx
This commit is contained in:
@@ -6,7 +6,10 @@ import (
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -55,6 +58,13 @@ func (s *AuthService) BindEmailIdentity(
|
||||
}
|
||||
|
||||
firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
|
||||
if firstRealEmailBind && s.entClient != nil {
|
||||
if err := s.bindEmailIdentityWithDefaultsTx(ctx, currentUser, normalizedEmail, hashedPassword); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return currentUser, nil
|
||||
}
|
||||
|
||||
currentUser.Email = normalizedEmail
|
||||
currentUser.PasswordHash = hashedPassword
|
||||
if err := s.userRepo.Update(ctx, currentUser); err != nil {
|
||||
@@ -126,3 +136,162 @@ func hasBindableEmailIdentitySubject(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return normalized != "" && !isReservedEmail(normalized)
|
||||
}
|
||||
|
||||
func (s *AuthService) bindEmailIdentityWithDefaultsTx(
|
||||
ctx context.Context,
|
||||
currentUser *User,
|
||||
email string,
|
||||
hashedPassword string,
|
||||
) error {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return s.bindEmailIdentityWithDefaults(ctx, tx.Client(), currentUser, email, hashedPassword)
|
||||
}
|
||||
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
if err := s.bindEmailIdentityWithDefaults(txCtx, tx.Client(), currentUser, email, hashedPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) bindEmailIdentityWithDefaults(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
currentUser *User,
|
||||
email string,
|
||||
hashedPassword string,
|
||||
) error {
|
||||
if client == nil || currentUser == nil || currentUser.ID <= 0 {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
oldEmail := currentUser.Email
|
||||
if _, err := client.User.UpdateOneID(currentUser.ID).
|
||||
SetEmail(email).
|
||||
SetPasswordHash(hashedPassword).
|
||||
Save(ctx); err != nil {
|
||||
if dbent.IsConstraintError(err) {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
|
||||
return fmt.Errorf("apply email first bind defaults: %w", err)
|
||||
}
|
||||
|
||||
updatedUser, err := client.User.Get(ctx, currentUser.ID)
|
||||
if err != nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
currentUser.Email = updatedUser.Email
|
||||
currentUser.PasswordHash = updatedUser.PasswordHash
|
||||
currentUser.Balance = updatedUser.Balance
|
||||
currentUser.Concurrency = updatedUser.Concurrency
|
||||
currentUser.UpdatedAt = updatedUser.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceBoundEmailAuthIdentityWithClient(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
userID int64,
|
||||
oldEmail string,
|
||||
newEmail string,
|
||||
source string,
|
||||
) error {
|
||||
newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail)
|
||||
if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail)
|
||||
if oldSubject == "" || oldSubject == newSubject {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := client.AuthIdentity.Delete().
|
||||
Where(
|
||||
authidentity.UserIDEQ(userID),
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ(oldSubject),
|
||||
).
|
||||
Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func ensureBoundEmailAuthIdentityWithClient(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
userID int64,
|
||||
subject string,
|
||||
source string,
|
||||
) error {
|
||||
if client == nil || userID <= 0 || subject == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.TrimSpace(source) == "" {
|
||||
source = "auth_service_email_bind"
|
||||
}
|
||||
|
||||
if err := client.AuthIdentity.Create().
|
||||
SetUserID(userID).
|
||||
SetProviderType("email").
|
||||
SetProviderKey("email").
|
||||
SetProviderSubject(subject).
|
||||
SetVerifiedAt(time.Now().UTC()).
|
||||
SetMetadata(map[string]any{"source": strings.TrimSpace(source)}).
|
||||
OnConflictColumns(
|
||||
authidentity.FieldProviderType,
|
||||
authidentity.FieldProviderKey,
|
||||
authidentity.FieldProviderSubject,
|
||||
).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ(subject),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if identity.UserID != userID {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeBoundEmailAuthIdentitySubject(email string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
if normalized == "" || isReservedEmail(normalized) {
|
||||
return ""
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package service_test
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -34,6 +35,20 @@ func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
|
||||
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
|
||||
}
|
||||
|
||||
type flakyEmailBindDefaultSubAssignerStub struct {
|
||||
err error
|
||||
calls []*service.AssignSubscriptionInput
|
||||
}
|
||||
|
||||
func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
|
||||
_ context.Context,
|
||||
input *service.AssignSubscriptionInput,
|
||||
) (*service.UserSubscription, bool, error) {
|
||||
cloned := *input
|
||||
s.calls = append(s.calls, &cloned)
|
||||
return nil, false, s.err
|
||||
}
|
||||
|
||||
func newAuthServiceForEmailBind(
|
||||
t *testing.T,
|
||||
settings map[string]string,
|
||||
@@ -187,6 +202,62 @@ func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testi
|
||||
require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
|
||||
}
|
||||
|
||||
func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) {
|
||||
assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")}
|
||||
cache := &emailBindCacheStub{
|
||||
data: &service.VerificationCodeData{
|
||||
Code: "123456",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
|
||||
},
|
||||
}
|
||||
svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
|
||||
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
|
||||
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
|
||||
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
|
||||
}, cache, assigner)
|
||||
|
||||
ctx := context.Background()
|
||||
originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain
|
||||
user, err := client.User.Create().
|
||||
SetEmail(originalEmail).
|
||||
SetUsername("legacy-rollback").
|
||||
SetPasswordHash("old-hash").
|
||||
SetBalance(2.5).
|
||||
SetConcurrency(1).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password")
|
||||
require.ErrorContains(t, err, "apply email first bind defaults")
|
||||
require.ErrorContains(t, err, "temporary assign failure")
|
||||
require.Nil(t, updatedUser)
|
||||
|
||||
storedUser, err := client.User.Get(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalEmail, storedUser.Email)
|
||||
require.Equal(t, "old-hash", storedUser.PasswordHash)
|
||||
require.Equal(t, 2.5, storedUser.Balance)
|
||||
require.Equal(t, 1, storedUser.Concurrency)
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.UserIDEQ(user.ID),
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ("rollback@example.com"),
|
||||
).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, identityCount)
|
||||
|
||||
require.Len(t, assigner.calls, 1)
|
||||
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||
}
|
||||
|
||||
func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
|
||||
cache := &emailBindCacheStub{
|
||||
data: &service.VerificationCodeData{
|
||||
|
||||
Reference in New Issue
Block a user