From fd0c9a130530eeab28c212cec601a3ef5ba1f843 Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 17:00:29 +0800 Subject: [PATCH 001/326] fix(payment): store provider config as plaintext JSON with legacy ciphertext fallback Without TOTP_ENCRYPTION_KEY, saved payment configs were lost on restart because the AES round-trip failed silently. Write new records as plaintext JSON; read path tries JSON first, falls back to legacy AES decrypt when a key is present, and treats unreadable values as empty so admins can re-enter them via the UI. --- backend/internal/payment/crypto.go | 11 ++- backend/internal/payment/load_balancer.go | 30 ++++-- .../internal/payment/load_balancer_test.go | 97 +++++++++++++++++++ .../service/payment_config_providers.go | 38 +++++--- 4 files changed, 151 insertions(+), 25 deletions(-) diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go index e39e957f..5467e50b 100644 --- a/backend/internal/payment/crypto.go +++ b/backend/internal/payment/crypto.go @@ -10,12 +10,15 @@ import ( "strings" ) +// AES256KeySize is the required key length (in bytes) for AES-256-GCM. +const AES256KeySize = 32 + // Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key. // The output format is "iv:authTag:ciphertext" where each component is base64-encoded, // matching the Node.js crypto.ts format for cross-compatibility. func Encrypt(plaintext string, key []byte) (string, error) { - if len(key) != 32 { - return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + if len(key) != AES256KeySize { + return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) } block, err := aes.NewCipher(key) @@ -52,8 +55,8 @@ func Encrypt(plaintext string, key []byte) (string, error) { // Decrypt decrypts a ciphertext string produced by Encrypt. // The input format is "iv:authTag:ciphertext" where each component is base64-encoded. func Decrypt(ciphertext string, key []byte) (string, error) { - if len(key) != 32 { - return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + if len(key) != AES256KeySize { + return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) } parts := strings.SplitN(ciphertext, ":", 3) diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index f0353173..52a1b011 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -261,6 +261,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns if err != nil { return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err) } + if config == nil { + config = map[string]string{} + } if selected.PaymentMode != "" { config["paymentMode"] = selected.PaymentMode @@ -275,16 +278,29 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns }, nil } -func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) { - plaintext, err := Decrypt(encrypted, lb.encryptionKey) - if err != nil { - return nil, err +// decryptConfig parses a stored provider config. +// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext. +// Unreadable values (legacy ciphertext without a valid key, or malformed data) +// are treated as empty so the service keeps running while the admin re-enters +// the config via the UI. +func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) { + if stored == "" { + return nil, nil } var config map[string]string - if err := json.Unmarshal([]byte(plaintext), &config); err != nil { - return nil, fmt.Errorf("unmarshal config: %w", err) + if err := json.Unmarshal([]byte(stored), &config); err == nil { + return config, nil } - return config, nil + if len(lb.encryptionKey) == AES256KeySize { + if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil { + if err := json.Unmarshal([]byte(plaintext), &config); err == nil { + return config, nil + } + } + } + slog.Warn("payment provider config unreadable, treating as empty for re-entry", + "stored_len", len(stored)) + return nil, nil } // GetInstanceDailyAmount returns the total completed order amount for an instance today. diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go index 04b3c25b..2bf4f6ac 100644 --- a/backend/internal/payment/load_balancer_test.go +++ b/backend/internal/payment/load_balancer_test.go @@ -452,6 +452,103 @@ func TestStartOfDay(t *testing.T) { } } +func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) { + t.Parallel() + + key := make([]byte, AES256KeySize) + for i := range key { + key[i] = byte(i + 1) + } + wrongKey := make([]byte, AES256KeySize) + for i := range wrongKey { + wrongKey[i] = byte(0xFF - i) + } + + plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}` + + legacyEncrypted, err := Encrypt(plaintextJSON, key) + if err != nil { + t.Fatalf("seed Encrypt: %v", err) + } + + tests := []struct { + name string + stored string + key []byte + want map[string]string + }{ + { + name: "empty stored returns nil map", + stored: "", + key: key, + want: nil, + }, + { + name: "plaintext JSON parses directly", + stored: plaintextJSON, + key: nil, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "plaintext JSON works even with key present", + stored: plaintextJSON, + key: key, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "legacy ciphertext with correct key decrypts", + stored: legacyEncrypted, + key: key, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "legacy ciphertext with no key treated as empty", + stored: legacyEncrypted, + key: nil, + want: nil, + }, + { + name: "legacy ciphertext with wrong key treated as empty", + stored: legacyEncrypted, + key: wrongKey, + want: nil, + }, + { + name: "garbage data treated as empty", + stored: "not-json-and-not-ciphertext", + key: key, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + lb := NewDefaultLoadBalancer(nil, tt.key) + got, err := lb.decryptConfig(tt.stored) + if err != nil { + t.Fatalf("decryptConfig unexpected error: %v", err) + } + if !stringMapEqual(got, tt.want) { + t.Fatalf("decryptConfig = %v, want %v", got, tt.want) + } + }) + } +} + +// stringMapEqual compares two map[string]string values; nil and empty are equal. +func stringMapEqual(a, b map[string]string) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if bv, ok := b[k]; !ok || bv != v { + return false + } + } + return true +} + // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 3c406b45..59337ad6 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "strconv" "strings" @@ -290,19 +291,29 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon return existing, nil } -func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) { - if encrypted == "" { +// decryptConfig parses a stored provider config. +// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext +// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including +// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty, +// letting the admin re-enter the config via the UI to complete the migration. +func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) { + if stored == "" { return nil, nil } - decrypted, err := payment.Decrypt(encrypted, s.encryptionKey) - if err != nil { - return nil, fmt.Errorf("decrypt config: %w", err) + var cfg map[string]string + if err := json.Unmarshal([]byte(stored), &cfg); err == nil { + return cfg, nil } - var raw map[string]string - if err := json.Unmarshal([]byte(decrypted), &raw); err != nil { - return nil, fmt.Errorf("unmarshal decrypted config: %w", err) + if len(s.encryptionKey) == payment.AES256KeySize { + if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil { + if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil { + return cfg, nil + } + } } - return raw, nil + slog.Warn("payment provider config unreadable, treating as empty for re-entry", + "stored_len", len(stored)) + return nil, nil } func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error { @@ -317,14 +328,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx) } +// encryptConfig serialises a provider config for storage. +// New records are written as plaintext JSON; the historical AES-GCM wrapping +// has been dropped but decryptConfig still accepts old ciphertext during migration. func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) { data, err := json.Marshal(cfg) if err != nil { return "", fmt.Errorf("marshal config: %w", err) } - enc, err := payment.Encrypt(string(data), s.encryptionKey) - if err != nil { - return "", fmt.Errorf("encrypt config: %w", err) - } - return enc, nil + return string(data), nil } From 44cdef7934168f167c5d433e5947aea3ac5a279d Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 17:00:45 +0800 Subject: [PATCH 002/326] fix(usage): subscription billing honours group rate multiplier MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Subscription-mode billing was consuming quota at TotalCost (raw) instead of ActualCost (TotalCost * RateMultiplier), so per-group rate multipliers — including free subscriptions (multiplier = 0) — were silently ignored. Switch the three subscription cost writes in buildUsageBillingCommand, finalizePostUsageBilling, and the legacy postUsageBilling fallback to ActualCost, and add a table-driven test covering 2x / 0.5x / free multipliers plus a balance-mode regression check. --- backend/internal/service/gateway_service.go | 16 ++-- ...teway_service_subscription_billing_test.go | 85 +++++++++++++++++++ 2 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 backend/internal/service/gateway_service_subscription_billing_test.go diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 4b4fc0bf..07a9e41c 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7317,8 +7317,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill cost := p.Cost if p.IsSubscriptionBill { - if cost.TotalCost > 0 { - if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { + // Subscription usage tracked by ActualCost so group rate multiplier + // consumes the quota at the expected speed. + if cost.ActualCost > 0 { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil { slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) } } @@ -7417,9 +7419,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage } } + // Record subscription / balance cost using ActualCost so the group (and any + // user-specific) rate multiplier consumes subscription quota at the expected + // speed. TotalCost remains the raw (pre-multiplier) value; downstream guards + // on "> 0" still correctly skip free subscriptions (RateMultiplier == 0). if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { cmd.SubscriptionID = &p.Subscription.ID - cmd.SubscriptionCost = p.Cost.TotalCost + cmd.SubscriptionCost = p.Cost.ActualCost } else if p.Cost.ActualCost > 0 { cmd.BalanceCost = p.Cost.ActualCost } @@ -7478,8 +7484,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu } if p.IsSubscriptionBill { - if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { - deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) + if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost) } } else if p.Cost.ActualCost > 0 && p.User != nil { deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) diff --git a/backend/internal/service/gateway_service_subscription_billing_test.go b/backend/internal/service/gateway_service_subscription_billing_test.go new file mode 100644 index 00000000..42a81035 --- /dev/null +++ b/backend/internal/service/gateway_service_subscription_billing_test.go @@ -0,0 +1,85 @@ +//go:build unit + +package service + +import ( + "testing" +) + +// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix +// that subscription-mode billing honours the group (and any user-specific) rate +// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost * +// RateMultiplier), not raw TotalCost. +func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) { + t.Parallel() + + groupID := int64(7) + subID := int64(42) + + tests := []struct { + name string + totalCost float64 + actualCost float64 + isSubscription bool + wantSub float64 + wantBalance float64 + }{ + { + name: "subscription with 2x multiplier consumes 2x quota", + totalCost: 1.0, + actualCost: 2.0, + isSubscription: true, + wantSub: 2.0, + wantBalance: 0, + }, + { + name: "subscription with 0.5x multiplier consumes 0.5x quota", + totalCost: 1.0, + actualCost: 0.5, + isSubscription: true, + wantSub: 0.5, + wantBalance: 0, + }, + { + name: "free subscription (multiplier 0) consumes no quota", + totalCost: 1.0, + actualCost: 0, + isSubscription: true, + wantSub: 0, + wantBalance: 0, + }, + { + name: "balance billing keeps using ActualCost (regression)", + totalCost: 1.0, + actualCost: 2.0, + isSubscription: false, + wantSub: 0, + wantBalance: 2.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + p := &postUsageBillingParams{ + Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost}, + User: &User{ID: 1}, + APIKey: &APIKey{ID: 2, GroupID: &groupID}, + Account: &Account{ID: 3}, + Subscription: &UserSubscription{ID: subID}, + IsSubscriptionBill: tt.isSubscription, + } + + cmd := buildUsageBillingCommand("req-1", nil, p) + if cmd == nil { + t.Fatal("buildUsageBillingCommand returned nil") + } + if cmd.SubscriptionCost != tt.wantSub { + t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub) + } + if cmd.BalanceCost != tt.wantBalance { + t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance) + } + }) + } +} From 948d8e6d024412cc2efda34ec41540fddb4bfb0b Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 17:01:01 +0800 Subject: [PATCH 003/326] fix(admin): prevent browser password manager from autofilling account API key Chrome's password manager matched the apikey-type account's Base URL + API Key inputs as a login form and autofilled the last saved password by domain, so editing a Gemini account could overwrite its apikey with a Claude key that shared the same Base URL. Add autocomplete="new-password" plus data-*-ignore attributes for 1Password / LastPass / Bitwarden to opt the field out of every major password manager's autofill. --- frontend/src/components/account/EditAccountModal.vue | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 1da32e2c..59ca0b9c 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -52,6 +52,10 @@ v-model="editApiKey" type="password" class="input font-mono" + autocomplete="new-password" + data-1p-ignore + data-lpignore="true" + data-bwignore="true" :placeholder=" account.platform === 'openai' ? 'sk-proj-...' From df57d2776b1a74e37470970f1bcf8942ad810e1d Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 18:32:12 +0800 Subject: [PATCH 004/326] fix(billing): reject rate_multiplier <= 0 on save; clamp negatives to 0 in compute MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 分组倍率和用户专属倍率在保存时没有校验,0 会触发计费层的 `<=0 → 1.0` 防御条款,结果订阅/余额分组按标准价扣费;完全是沉默地绕过了业务规则。 - 保存校验(admin_service):CreateGroup / UpdateGroup / BatchSetGroupRateMultipliers / UpdateUser.SyncUserGroupRates 全部要求 > 0 - 计算层(billing_service):三处 `<=0 → 1.0` 改为 `<0 → 0`;负数按 0 结算, 避免配置异常被静默按 1x 收费 - 前端:分组倍率 / 用户专属倍率输入 min 统一到 0.001 - 删除未使用的 IsFreeSubscription 方法 测试:新增 billing_service_rate_multiplier_test.go 端到端验证;更新原有锁定 旧 `<=0 → 1.0` 行为的测试。 --- backend/internal/service/admin_service.go | 21 +++++++ .../service/admin_service_group_test.go | 6 ++ backend/internal/service/billing_service.go | 16 ++--- .../service/billing_service_image_test.go | 5 +- .../billing_service_rate_multiplier_test.go | 63 +++++++++++++++++++ .../internal/service/billing_service_test.go | 28 --------- .../service/billing_service_unified_test.go | 40 ++++-------- backend/internal/service/group.go | 4 -- .../openai_gateway_record_usage_test.go | 2 +- .../admin/group/GroupRateMultipliersModal.vue | 2 +- .../admin/user/UserAllowedGroupsModal.vue | 4 +- 11 files changed, 119 insertions(+), 72 deletions(-) create mode 100644 backend/internal/service/billing_service_rate_multiplier_test.go diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 7c26a47c..701f3659 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -586,6 +586,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI } func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { + // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率) + if input.GroupRates != nil { + for groupID, rate := range input.GroupRates { + if rate != nil && *rate <= 0 { + return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID) + } + } + } + user, err := s.userRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -811,6 +820,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro } func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { + if input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } + platform := input.Platform if platform == "" { platform = PlatformAnthropic @@ -1050,6 +1063,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.Platform = input.Platform } if input.RateMultiplier != nil { + if *input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } group.RateMultiplier = *input.RateMultiplier } if input.IsExclusive != nil { @@ -1286,6 +1302,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro if s.userGroupRateRepo == nil { return nil } + for _, e := range entries { + if e.RateMultiplier <= 0 { + return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID) + } + } return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index a4c6d0ca..41d2c26a 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformOpenAI, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeSubscription, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t * _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAntigravity, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing. group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &zero, }) diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 32a54cbe..c9f32b3b 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -448,8 +448,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, }) } - if input.RateMultiplier <= 0 { - input.RateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。 + if input.RateMultiplier < 0 { + input.RateMultiplier = 0 } var breakdown *CostBreakdown @@ -493,8 +494,9 @@ func (s *BillingService) computeTokenBreakdown( rateMultiplier float64, serviceTier string, applyLongCtx bool, ) *CostBreakdown { - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。 + if rateMultiplier < 0 { + rateMultiplier = 0 } inputPrice := pricing.InputPricePerToken @@ -831,9 +833,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag // 计算总费用 totalCost := unitPrice * float64(imageCount) - // 应用倍率 - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣) + if rateMultiplier < 0 { + rateMultiplier = 0 } actualCost := totalCost * rateMultiplier diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index fa90f6bb..8d3ca987 100644 --- a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) { require.Equal(t, 0.0, cost.ActualCost) } -// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0 +// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费 +// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。 func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { svc := &BillingService{} cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) require.InDelta(t, 0.201, cost.TotalCost, 0.0001) - require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 diff --git a/backend/internal/service/billing_service_rate_multiplier_test.go b/backend/internal/service/billing_service_rate_multiplier_test.go new file mode 100644 index 00000000..83788196 --- /dev/null +++ b/backend/internal/service/billing_service_rate_multiplier_test.go @@ -0,0 +1,63 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被 +// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。 +func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 // ActualCost / TotalCost + }{ + {"negative clamped to 0", -1.5, 0}, + {"zero passes through as 0 (defense in depth)", 0, 0}, + {"positive 2x applied", 2.0, 2.0}, + {"positive 0.5x applied", 0.5, 0.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier) + require.NoError(t, err) + require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero") + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} + +// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径 +// 同样遵循"负数 → 0"语义。 +func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + price := 0.04 + cfg := &ImagePriceConfig{Price1K: &price} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 + }{ + {"negative clamped to 0", -0.5, 0}, + {"zero passes through", 0, 0}, + {"positive 3x applied", 3.0, 3.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier) + require.NotNil(t, cost) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 2cf134e2..fc8361c7 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) { require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) } -func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) -} - -func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) -} - func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { svc := newTestBillingService() diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go index 694c3384..e6a92d1a 100644 --- a/backend/internal/service/billing_service_unified_test.go +++ b/backend/internal/service/billing_service_unified_test.go @@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) { require.Equal(t, string(BillingModeImage), cost.BillingMode) } -func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为: +// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。 +func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} - costZero, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, - RateMultiplier: 0, // should default to 1.0 + RateMultiplier: 0, Resolver: resolver, }) require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 1.0, - Resolver: resolver, - }) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } -func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为: +// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。 +func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000} - costNeg, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, @@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) Resolver: resolver, }) require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 1.0, - Resolver: resolver, - }) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 12262613..64434ae1 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool { return g.SubscriptionType == SubscriptionTypeSubscription } -func (g *Group) IsFreeSubscription() bool { - return g.IsSubscriptionType() && g.RateMultiplier == 0 -} - func (g *Group) HasDailyLimit() bool { return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0 } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index e6fa94aa..6fa8a5bd 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel Model: "gpt-5.1", Duration: time.Second, }, - APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, + APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}}, User: &User{ID: 200}, Account: &Account{ID: 300}, Subscription: subscription, diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue index bf79bea2..41b2e63c 100644 --- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue +++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue @@ -166,7 +166,7 @@ Date: Fri, 17 Apr 2026 22:07:15 +0800 Subject: [PATCH 005/326] feat(gateway): raise upstream response read limit 8MB -> 128MB (configurable) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 图片生成 API 返回的 base64 内联图响应经常超过 8MB 单次读取上限,被 ReadUpstreamResponseBody 拦截成 502 upstream_error。 单张 4K PNG base64 最坏约 67MB,多张候选图或 imageSize=4K 的 image_generation 一次请求能轻松到 30MB+。把默认上限提到 128MB 能覆盖 2-3 张 4K 图,相对 请求体上限 256MB 仍有缓冲;同时抽出 config.DefaultUpstreamResponseReadMaxBytes 共享常量,viper 默认值和 service 层兜底共用,消除两处同步魔法数字。 仍可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。 --- backend/internal/config/config.go | 7 ++++++- backend/internal/service/upstream_response_limit.go | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index dd9a4e58..15592905 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -52,6 +52,11 @@ const ( ConnectionPoolIsolationAccountProxy = "account_proxy" ) +// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。 +// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。 +// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。 +const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024 + type Config struct { Server ServerConfig `mapstructure:"server"` Log LogConfig `mapstructure:"log"` @@ -1407,7 +1412,7 @@ func setDefaults() { viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) - viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.gemini_debug_response_headers", false) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go index a0444d52..ddf0e818 100644 --- a/backend/internal/service/upstream_response_limit.go +++ b/backend/internal/service/upstream_response_limit.go @@ -12,7 +12,9 @@ import ( var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large") -const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024 +// defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes, +// 仅在 cfg 为 nil 时作为兜底(测试或极端场景)。 +const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 { if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 { From 61a008f7e4e0e019f00e63409389aa5aa1058b0a Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 23:05:58 +0800 Subject: [PATCH 006/326] chore(payment): mark legacy AES ciphertext fallback as deprecated MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 明文 JSON 已经是新写入的默认格式;保留 AES 密文读取仅为兼容迁移期间的旧 记录,一旦所有部署通过管理后台重存过一次即可删除。标记为 deprecated 并加 TODO,几个版本后统一清理掉:payment.Encrypt / payment.Decrypt、两处 decryptConfig 的 AES 分支、PaymentConfigService.encryptionKey 和 DefaultLoadBalancer.encryptionKey 字段。 --- backend/internal/payment/crypto.go | 10 ++++++++++ backend/internal/payment/load_balancer.go | 7 +++++++ backend/internal/service/payment_config_providers.go | 6 ++++++ 3 files changed, 23 insertions(+) diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go index 5467e50b..0581469d 100644 --- a/backend/internal/payment/crypto.go +++ b/backend/internal/payment/crypto.go @@ -16,6 +16,11 @@ const AES256KeySize = 32 // Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key. // The output format is "iv:authTag:ciphertext" where each component is base64-encoded, // matching the Node.js crypto.ts format for cross-compatibility. +// +// Deprecated: payment provider configs are now stored as plaintext JSON. +// This function is kept only for seeding legacy ciphertext in tests and for +// the transitional Decrypt fallback. Scheduled for removal after all live +// deployments complete migration by re-saving their configs. func Encrypt(plaintext string, key []byte) (string, error) { if len(key) != AES256KeySize { return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) @@ -54,6 +59,11 @@ func Encrypt(plaintext string, key []byte) (string, error) { // Decrypt decrypts a ciphertext string produced by Encrypt. // The input format is "iv:authTag:ciphertext" where each component is base64-encoded. +// +// Deprecated: payment provider configs are now stored as plaintext JSON. +// This function remains only as a read-path fallback for pre-migration +// ciphertext records. Scheduled for removal once all deployments re-save +// their provider configs through the admin UI. func Decrypt(ciphertext string, key []byte) (string, error) { if len(key) != AES256KeySize { return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index 52a1b011..ec244cd6 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -283,6 +283,11 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns // Unreadable values (legacy ciphertext without a valid key, or malformed data) // are treated as empty so the service keeps running while the admin re-enters // the config via the UI. +// +// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a +// transitional compatibility shim for pre-plaintext records. Remove it (and +// the encryptionKey field + the Decrypt import) after a few releases once all +// live deployments have re-saved their provider configs through the UI. func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) { if stored == "" { return nil, nil @@ -291,7 +296,9 @@ func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, if err := json.Unmarshal([]byte(stored), &config); err == nil { return config, nil } + // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal. if len(lb.encryptionKey) == AES256KeySize { + //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil { if err := json.Unmarshal([]byte(plaintext), &config); err == nil { return config, nil diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 59337ad6..8e470525 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -296,6 +296,10 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon // ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including // legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty, // letting the admin re-enter the config via the UI to complete the migration. +// +// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional +// shim for pre-plaintext records. Remove it (and the encryptionKey field) after +// a few releases once all live deployments have re-saved their provider configs. func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) { if stored == "" { return nil, nil @@ -304,7 +308,9 @@ func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, if err := json.Unmarshal([]byte(stored), &cfg); err == nil { return cfg, nil } + // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal. if len(s.encryptionKey) == payment.AES256KeySize { + //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil { if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil { return cfg, nil From c3cb0280ef8dff949ad7ad8f84793872b9304686 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 19 Apr 2026 01:40:25 +0800 Subject: [PATCH 007/326] fix(payment): alipay redirect-only flow, H5 detection and popup sizing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The native Alipay provider previously tried to embed the payment page URL into a QR code on the client — the URL is not a scannable payload so the QR never worked. Merchants also hit a H5 detection mismatch whenever the backend UA sniffer missed iPadOS 13+ or embedded browsers, and the popup window was too small for Alipay's standard checkout layout (QR + account-login panel on the right), forcing the user to scroll horizontally and vertically. Changes: Backend - alipay.go: drop QR-on-URL path. Use redirect-only flow — alipay.trade.page.pay for PC (returns a gateway URL the browser opens in a new window) and alipay.trade.wap.pay for H5 (returns a URL the browser jumps to). Both flows produce pages on openapi.alipaydev.com / excashier.alipay.com; the client never renders a QR itself. - payment_handler.go: add optional is_mobile bool to CreateOrderRequest so the frontend can declare the device explicitly. Server still falls back to UA sniffing when absent. Frontend - types/payment.ts, PaymentView.vue: declare is_mobile in CreateOrderRequest and pass the computed isMobileDevice() value. - providerConfig.ts: replace the two fixed POPUP_WINDOW_FEATURES constants with getPaymentPopupFeatures(), which prefers 1250×900 (Alipay's checkout footprint), clamps to window.screen.avail* and centers the popup so it never overflows on smaller laptops. - PaymentQRDialog.vue, PaymentStatusPanel.vue, StripePaymentInline.vue, PaymentView.vue: use the new helper at all popup call sites. --- backend/internal/handler/payment_handler.go | 10 +++- backend/internal/payment/provider/alipay.go | 48 ++++++++++--------- .../components/payment/PaymentQRDialog.vue | 4 +- .../components/payment/PaymentStatusPanel.vue | 4 +- .../payment/StripePaymentInline.vue | 4 +- .../src/components/payment/providerConfig.ts | 21 ++++++-- frontend/src/types/payment.ts | 1 + frontend/src/views/user/PaymentView.vue | 5 +- 8 files changed, 62 insertions(+), 35 deletions(-) diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 1ddb8ae2..854dca54 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -206,6 +206,10 @@ type CreateOrderRequest struct { PaymentType string `json:"payment_type" binding:"required"` OrderType string `json:"order_type"` PlanID int64 `json:"plan_id"` + // IsMobile lets the frontend declare its mobile status directly. When + // nil we fall back to User-Agent heuristics (which miss iPadOS / some + // embedded browsers that strip the "Mobile" keyword). + IsMobile *bool `json:"is_mobile,omitempty"` } // CreateOrder creates a new payment order. @@ -222,12 +226,16 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { return } + mobile := isMobile(c) + if req.IsMobile != nil { + mobile = *req.IsMobile + } result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{ UserID: subject.UserID, Amount: req.Amount, PaymentType: req.PaymentType, ClientIP: c.ClientIP(), - IsMobile: isMobile(c), + IsMobile: mobile, SrcHost: c.Request.Host, SrcURL: c.Request.Referer(), OrderType: req.OrderType, diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go index af8a90c6..fe8ea89c 100644 --- a/backend/internal/payment/provider/alipay.go +++ b/backend/internal/payment/provider/alipay.go @@ -15,8 +15,8 @@ import ( // Alipay product codes. const ( - alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" alipayProductCodeWapPay = "QUICK_WAP_WAY" + alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" ) // Alipay response constants. @@ -79,7 +79,12 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeAlipay} } -// CreatePayment creates an Alipay payment page URL. +// CreatePayment creates an Alipay payment using redirect-only flow: +// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to. +// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a +// new window; Alipay's own page then shows login/QR. We intentionally do +// NOT encode the URL into a QR on the client (it isn't a scannable payload +// and would produce an invalid scan result). func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { client, err := a.getClient() if err != nil { @@ -96,31 +101,31 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque } if req.IsMobile { - return a.createTrade(client, req, notifyURL, returnURL, true) + return a.createWapTrade(client, req, notifyURL, returnURL) } - return a.createTrade(client, req, notifyURL, returnURL, false) + return a.createPagePayTrade(client, req, notifyURL, returnURL) } -func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) { - if isMobile { - param := alipay.TradeWapPay{} - param.OutTradeNo = req.OrderID - param.TotalAmount = req.Amount - param.Subject = req.Subject - param.ProductCode = alipayProductCodeWapPay - param.NotifyURL = notifyURL - param.ReturnURL = returnURL +func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { + param := alipay.TradeWapPay{} + param.OutTradeNo = req.OrderID + param.TotalAmount = req.Amount + param.Subject = req.Subject + param.ProductCode = alipayProductCodeWapPay + param.NotifyURL = notifyURL + param.ReturnURL = returnURL - payURL, err := client.TradeWapPay(param) - if err != nil { - return nil, fmt.Errorf("alipay TradeWapPay: %w", err) - } - return &payment.CreatePaymentResponse{ - TradeNo: req.OrderID, - PayURL: payURL.String(), - }, nil + payURL, err := client.TradeWapPay(param) + if err != nil { + return nil, fmt.Errorf("alipay TradeWapPay: %w", err) } + return &payment.CreatePaymentResponse{ + TradeNo: req.OrderID, + PayURL: payURL.String(), + }, nil +} +func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { param := alipay.TradePagePay{} param.OutTradeNo = req.OrderID param.TotalAmount = req.Amount @@ -136,7 +141,6 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq return &payment.CreatePaymentResponse{ TradeNo: req.OrderID, PayURL: payURL.String(), - QRCode: payURL.String(), }, nil } diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue index b9026e78..db90c3b6 100644 --- a/frontend/src/components/payment/PaymentQRDialog.vue +++ b/frontend/src/components/payment/PaymentQRDialog.vue @@ -79,7 +79,7 @@ import { usePaymentStore } from '@/stores/payment' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' import { extractApiErrorMessage } from '@/utils/apiError' -import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig' +import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { PaymentOrder } from '@/types/payment' import QRCode from 'qrcode' import alipayIcon from '@/assets/icons/alipay.svg' @@ -147,7 +147,7 @@ function getLogoForType(): string | null { function reopenPopup() { if (props.payUrl) { - window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures()) } } diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue index 974dee66..17541e59 100644 --- a/frontend/src/components/payment/PaymentStatusPanel.vue +++ b/frontend/src/components/payment/PaymentStatusPanel.vue @@ -125,7 +125,7 @@ import { usePaymentStore } from '@/stores/payment' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' import { extractApiErrorMessage } from '@/utils/apiError' -import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig' +import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { PaymentOrder } from '@/types/payment' import Icon from '@/components/icons/Icon.vue' import QRCode from 'qrcode' @@ -194,7 +194,7 @@ const countdownDisplay = computed(() => { function reopenPopup() { if (props.payUrl) { - window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures()) } } diff --git a/frontend/src/components/payment/StripePaymentInline.vue b/frontend/src/components/payment/StripePaymentInline.vue index b8fd55ef..3ddff8c8 100644 --- a/frontend/src/components/payment/StripePaymentInline.vue +++ b/frontend/src/components/payment/StripePaymentInline.vue @@ -70,7 +70,7 @@ import { useRouter } from 'vue-router' import { extractApiErrorMessage } from '@/utils/apiError' import { paymentAPI } from '@/api/payment' import { useAppStore } from '@/stores' -import { STRIPE_POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig' +import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { Stripe, StripeElements } from '@stripe/stripe-js' import Icon from '@/components/icons/Icon.vue' @@ -151,7 +151,7 @@ async function handlePay() { amount: String(props.payAmount), }, }).href - const popup = window.open(popupUrl, 'paymentPopup', STRIPE_POPUP_WINDOW_FEATURES) + const popup = window.open(popupUrl, 'paymentPopup', getPaymentPopupFeatures()) const onReady = (event: MessageEvent) => { if (event.source !== popup || event.data?.type !== 'STRIPE_POPUP_READY') return diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts index a83787fd..bf2d4177 100644 --- a/frontend/src/components/payment/providerConfig.ts +++ b/frontend/src/components/payment/providerConfig.ts @@ -43,11 +43,24 @@ export const METHOD_ORDER = ['alipay', 'alipay_direct', 'wxpay', 'wxpay_direct', export const PAYMENT_MODE_QRCODE = 'qrcode' export const PAYMENT_MODE_POPUP = 'popup' -/** Window features for payment popup windows */ -export const POPUP_WINDOW_FEATURES = 'width=1000,height=750,left=100,top=80,scrollbars=yes,resizable=yes' +/** Preferred popup size for payment gateways. Alipay's standard checkout + * (QR + account login panel) needs ~1200×900 to render without any scrolling. */ +const PAYMENT_POPUP_PREFERRED_WIDTH = 1250 +const PAYMENT_POPUP_PREFERRED_HEIGHT = 900 -/** Wider popup for Stripe redirect methods (Alipay checkout page needs ~1200px) */ -export const STRIPE_POPUP_WINDOW_FEATURES = 'width=1250,height=780,left=80,top=60,scrollbars=yes,resizable=yes' +/** Build a window.open features string sized to fit within the current screen + * while preferring the above dimensions. Centers the popup on the available + * work area so nothing is clipped on smaller laptop displays. */ +export function getPaymentPopupFeatures(): string { + const screen = typeof window !== 'undefined' ? window.screen : null + const availW = screen?.availWidth ?? PAYMENT_POPUP_PREFERRED_WIDTH + const availH = screen?.availHeight ?? PAYMENT_POPUP_PREFERRED_HEIGHT + const width = Math.min(PAYMENT_POPUP_PREFERRED_WIDTH, availW - 40) + const height = Math.min(PAYMENT_POPUP_PREFERRED_HEIGHT, availH - 40) + const left = Math.max(0, Math.floor((availW - width) / 2)) + const top = Math.max(0, Math.floor((availH - height) / 2)) + return `width=${width},height=${height},left=${left},top=${top},scrollbars=yes,resizable=yes` +} /** Webhook paths for each provider (relative to origin). */ export const WEBHOOK_PATHS: Record = { diff --git a/frontend/src/types/payment.ts b/frontend/src/types/payment.ts index 7ecbb9a9..6f2eec51 100644 --- a/frontend/src/types/payment.ts +++ b/frontend/src/types/payment.ts @@ -154,6 +154,7 @@ export interface CreateOrderRequest { payment_type: string order_type: string plan_id?: number + is_mobile?: boolean } export interface CreateOrderResult { diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index e91df5da..3f1401b3 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -277,7 +277,7 @@ import type { SubscriptionPlan, CheckoutInfoResponse, OrderType } from '@/types/ import AppLayout from '@/components/layout/AppLayout.vue' import AmountInput from '@/components/payment/AmountInput.vue' import PaymentMethodSelector from '@/components/payment/PaymentMethodSelector.vue' -import { METHOD_ORDER, POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig' +import { METHOD_ORDER, getPaymentPopupFeatures } from '@/components/payment/providerConfig' import { platformAccentBarClass, platformBadgeLightClass, platformBadgeClass, platformTextClass, platformLabel } from '@/utils/platformColors' import SubscriptionPlanCard from '@/components/payment/SubscriptionPlanCard.vue' import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue' @@ -551,9 +551,10 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n payment_type: selectedMethod.value, order_type: orderType, plan_id: planId, + is_mobile: isMobileDevice(), }) const openWindow = (url: string) => { - const win = window.open(url, 'paymentPopup', POPUP_WINDOW_FEATURES) + const win = window.open(url, 'paymentPopup', getPaymentPopupFeatures()) if (!win || win.closed) { window.location.href = url } From 235f710853b3e44b5b2fa236cfefd0b9d0c2a044 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 19 Apr 2026 01:46:50 +0800 Subject: [PATCH 008/326] feat(payment): redact provider secrets in admin config API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Admin GET /api/v1/admin/payment/providers previously returned every config value — including privateKey / apiV3Key / secretKey etc. — verbatim. Any future XSS on the admin UI would hand attackers the full set of production payment credentials, and the plaintext values sat unnecessarily in browser memory for every operator. Treat those fields as write-only from the admin surface: - decryptAndMaskConfig() strips sensitive keys from the GET response. The authoritative list is an explicit per-provider registry that mirrors the frontend's PROVIDER_CONFIG_FIELDS sensitive flag: alipay → privateKey, publicKey, alipayPublicKey wxpay → privateKey, apiV3Key, publicKey stripe → secretKey, webhookSecret (publishableKey stays plain) easypay → pkey Payment runtime still reads the full config via decryptConfig, so nothing at the gateway changes. - mergeConfig() treats an empty value for a sensitive key as "leave unchanged" — the admin UI omits unchanged secrets so operators can tweak non-sensitive settings without re-entering credentials. - Admin dialog (PaymentProviderDialog.vue): * secret inputs get autocomplete="new-password", data-1p-ignore, data-lpignore and data-bwignore so password managers do not offer to save provider credentials * in edit mode the required-field check skips sensitive fields (empty is the "keep existing" signal) and the placeholder shows "leave empty to keep" instead of the default example value * create mode still requires every non-optional field, including secrets, since there is nothing to preserve - Unit test renamed to TestIsSensitiveProviderConfigField, covers the per-provider registry and specifically asserts that Stripe's publishableKey is NOT treated as a secret. --- .../service/payment_config_providers.go | 78 +++++++++++++++---- .../service/payment_config_providers_test.go | 61 +++++++++------ .../payment/PaymentProviderDialog.vue | 23 ++++-- 3 files changed, 118 insertions(+), 44 deletions(-) diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 8e470525..3d1e4dc4 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -52,7 +52,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte AllowUserRefund: inst.AllowUserRefund, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } - resp.Config, err = s.decryptAndMaskConfig(inst.Config) + resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config) if err != nil { return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err) } @@ -61,8 +61,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte return result, nil } -func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) { - return s.decryptConfig(encrypted) +// decryptAndMaskConfig returns the stored config with sensitive fields omitted. +// Admin UIs display masked placeholders for these; the raw values never leave +// the server. Callers that need the full config (e.g. payment runtime) must +// use decryptConfig directly. +func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) { + cfg, err := s.decryptConfig(encrypted) + if err != nil { + return nil, err + } + if cfg == nil { + return nil, nil + } + masked := make(map[string]string, len(cfg)) + for k, v := range cfg { + if isSensitiveProviderConfigField(providerKey, k) { + continue + } + masked[k] = v + } + return masked, nil } // pendingOrderStatuses are order statuses considered "in progress". @@ -72,16 +90,27 @@ var pendingOrderStatuses = []string{ payment.OrderStatusRecharging, } -var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"} +// providerSensitiveConfigFields is the authoritative list of config keys that +// are treated as secrets per provider. Must stay in sync with the frontend +// definition at frontend/src/components/payment/providerConfig.ts +// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true). +// +// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl, +// stripe publishableKey) are returned in plaintext by the admin GET API. +var providerSensitiveConfigFields = map[string]map[string]struct{}{ + payment.TypeEasyPay: {"pkey": {}}, + payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}}, + payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}}, + payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}}, +} -func isSensitiveConfigField(fieldName string) bool { - lower := strings.ToLower(fieldName) - for _, p := range sensitiveConfigPatterns { - if strings.Contains(lower, p) { - return true - } +func isSensitiveProviderConfigField(providerKey, fieldName string) bool { + fields, ok := providerSensitiveConfigFields[providerKey] + if !ok { + return false } - return false + _, found := fields[strings.ToLower(fieldName)] + return found } func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) { @@ -137,10 +166,26 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error { // NOTE: This function exceeds 30 lines due to per-field nil-check patch update // boilerplate and pending-order safety checks. func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { + var cachedInst *dbent.PaymentProviderInstance + loadInst := func() (*dbent.PaymentProviderInstance, error) { + if cachedInst != nil { + return cachedInst, nil + } + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("load provider instance: %w", err) + } + cachedInst = inst + return inst, nil + } if req.Config != nil { + inst, err := loadInst() + if err != nil { + return nil, err + } hasSensitive := false - for k := range req.Config { - if isSensitiveConfigField(k) && req.Config[k] != "" { + for k, v := range req.Config { + if v != "" && isSensitiveProviderConfigField(inst.ProviderKey, k) { hasSensitive = true break } @@ -283,9 +328,14 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err) } if existing == nil { - return newConfig, nil + existing = map[string]string{} } for k, v := range newConfig { + // Preserve existing secrets when the client submits an empty value + // (admin UI omits the value to indicate "leave unchanged"). + if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) { + continue + } existing[k] = v } return existing, nil diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index 2aaa874f..bc2a9b18 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -97,41 +97,52 @@ func TestValidateProviderRequest(t *testing.T) { } } -func TestIsSensitiveConfigField(t *testing.T) { +func TestIsSensitiveProviderConfigField(t *testing.T) { t.Parallel() tests := []struct { - field string - wantSen bool + providerKey string + field string + wantSen bool }{ - // Sensitive fields (contain key/secret/private/password/pkey patterns) - {"secretKey", true}, - {"apiSecret", true}, - {"pkey", true}, - {"privateKey", true}, - {"apiPassword", true}, - {"appKey", true}, - {"SECRET_TOKEN", true}, - {"PrivateData", true}, - {"PASSWORD", true}, - {"mySecretValue", true}, + // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets + {"stripe", "secretKey", true}, + {"stripe", "webhookSecret", true}, + {"stripe", "SecretKey", true}, // case-insensitive + {"stripe", "publishableKey", false}, + {"stripe", "appId", false}, - // Non-sensitive fields - {"appId", false}, - {"mchId", false}, - {"apiBase", false}, - {"endpoint", false}, - {"merchantNo", false}, - {"paymentMode", false}, - {"notifyUrl", false}, + // Alipay + {"alipay", "privateKey", true}, + {"alipay", "publicKey", true}, + {"alipay", "alipayPublicKey", true}, + {"alipay", "appId", false}, + {"alipay", "notifyUrl", false}, + + // Wxpay + {"wxpay", "privateKey", true}, + {"wxpay", "apiV3Key", true}, + {"wxpay", "publicKey", true}, + {"wxpay", "publicKeyId", false}, + {"wxpay", "certSerial", false}, + {"wxpay", "mchId", false}, + + // EasyPay + {"easypay", "pkey", true}, + {"easypay", "pid", false}, + {"easypay", "apiBase", false}, + + // Unknown provider: never sensitive + {"unknown", "secretKey", false}, } for _, tc := range tests { - t.Run(tc.field, func(t *testing.T) { + tc := tc + t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) { t.Parallel() - got := isSensitiveConfigField(tc.field) - assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field) + got := isSensitiveProviderConfigField(tc.providerKey, tc.field) + assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field) }) } } diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue index 10c1bfea..624ddcdd 100644 --- a/frontend/src/components/payment/PaymentProviderDialog.vue +++ b/frontend/src/components/payment/PaymentProviderDialog.vue @@ -88,13 +88,24 @@ v-model="config[field.key]" rows="3" class="input font-mono text-xs" + autocomplete="new-password" + data-1p-ignore + data-lpignore="true" + data-bwignore="true" + spellcheck="false" + :placeholder="editing ? t('admin.accounts.leaveEmptyToKeep') : ''" />
- +
@@ -1391,7 +1391,7 @@ {{ t('admin.settings.oidc.validateIdToken') }}
- +
@@ -3024,7 +3024,7 @@ const form = reactive({ oidc_connect_redirect_url: '', oidc_connect_frontend_redirect_url: '/auth/oidc/callback', oidc_connect_token_auth_method: 'client_secret_post', - oidc_connect_use_pkce: false, + oidc_connect_use_pkce: true, oidc_connect_validate_id_token: true, oidc_connect_allowed_signing_algs: 'RS256,ES256,PS256', oidc_connect_clock_skew_seconds: 120, @@ -3613,8 +3613,8 @@ async function saveSettings() { oidc_connect_redirect_url: form.oidc_connect_redirect_url, oidc_connect_frontend_redirect_url: form.oidc_connect_frontend_redirect_url, oidc_connect_token_auth_method: form.oidc_connect_token_auth_method, - oidc_connect_use_pkce: form.oidc_connect_use_pkce, - oidc_connect_validate_id_token: form.oidc_connect_validate_id_token, + oidc_connect_use_pkce: true, + oidc_connect_validate_id_token: true, oidc_connect_allowed_signing_algs: form.oidc_connect_allowed_signing_algs, oidc_connect_clock_skew_seconds: form.oidc_connect_clock_skew_seconds, oidc_connect_require_email_verified: form.oidc_connect_require_email_verified, From fbd0a2e3c488720025a3408e9db234407d8aef9b Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 16:27:23 +0800 Subject: [PATCH 018/326] feat: carry suggested third-party profile through pending oauth --- .../internal/handler/auth_linuxdo_oauth.go | 201 +++++++++---- .../handler/auth_linuxdo_oauth_test.go | 12 +- .../handler/auth_oauth_pending_flow.go | 263 ++++++++++++++++++ .../handler/auth_oauth_pending_flow_test.go | 40 +++ backend/internal/handler/auth_oidc_oauth.go | 47 +++- .../internal/handler/auth_oidc_oauth_test.go | 20 ++ frontend/src/api/auth.ts | 24 +- 7 files changed, 534 insertions(+), 73 deletions(-) create mode 100644 backend/internal/handler/auth_oauth_pending_flow.go create mode 100644 backend/internal/handler/auth_oauth_pending_flow_test.go diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 0c7c2da7..2f182642 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -87,20 +87,25 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + secureCookie := isRequestHTTPS(c) setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) - codeChallenge := "" - if cfg.UsePKCE { - verifier, err := oauth.GenerateCodeVerifier() - if err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) - return - } - codeChallenge = oauth.GenerateCodeChallenge(verifier) - setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return } + codeChallenge := oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) redirectURI := strings.TrimSpace(cfg.RedirectURL) if redirectURI == "" { @@ -161,14 +166,16 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { if redirectTo == "" { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } - codeVerifier := "" - if cfg.UsePKCE { - codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) - if codeVerifier == "" { - redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") - return - } + codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return } redirectURI := strings.TrimSpace(cfg.RedirectURL) @@ -198,7 +205,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } - email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) if err != nil { log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") @@ -215,16 +222,32 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { - pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) - if tokenErr != nil { - redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: "login", + Identity: service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + }, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + }, + CompletionResponse: map[string]any{ + "error": "invitation_required", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") return } - fragment := url.Values{} - fragment.Set("error", "invitation_required") - fragment.Set("pending_oauth_token", pendingToken) - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + redirectToFrontendCallback(c, frontendCallback) return } // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 @@ -232,18 +255,39 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } - fragment := url.Values{} - fragment.Set("access_token", tokenPair.AccessToken) - fragment.Set("refresh_token", tokenPair.RefreshToken) - fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) - fragment.Set("token_type", "Bearer") - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: "login", + Identity: service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + }, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + }, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) } type completeLinuxDoOAuthRequest struct { - PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` } // CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating @@ -256,9 +300,38 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) return } @@ -267,6 +340,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) c.JSON(http.StatusOK, gin.H{ "access_token": tokenPair.AccessToken, @@ -303,9 +384,7 @@ func linuxDoExchangeCode( form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) - if cfg.UsePKCE { - form.Set("code_verifier", codeVerifier) - } + form.Set("code_verifier", codeVerifier) r := client.R(). SetContext(ctx). @@ -353,11 +432,11 @@ func linuxDoFetchUserInfo( ctx context.Context, cfg config.LinuxDoConnectConfig, token *linuxDoTokenResponse, -) (email string, username string, subject string, err error) { +) (email string, username string, subject string, displayName string, avatarURL string, err error) { client := req.C().SetTimeout(30 * time.Second) authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) if err != nil { - return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) } resp, err := client.R(). @@ -366,16 +445,16 @@ func linuxDoFetchUserInfo( SetHeader("Authorization", authorization). Get(cfg.UserInfoURL) if err != nil { - return "", "", "", fmt.Errorf("request userinfo: %w", err) + return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err) } if !resp.IsSuccessState() { - return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) } return linuxDoParseUserInfo(resp.String(), cfg) } -func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) { email = firstNonEmpty( getGJSON(body, cfg.UserInfoEmailPath), getGJSON(body, "email"), @@ -400,12 +479,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s getGJSON(body, "user.id"), ) + displayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "user.name"), + getGJSON(body, "user.username"), + username, + ) + avatarURL = firstNonEmpty( + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "picture"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) + subject = strings.TrimSpace(subject) if subject == "" { - return "", "", "", errors.New("userinfo missing id field") + return "", "", "", "", "", errors.New("userinfo missing id field") } if !isSafeLinuxDoSubject(subject) { - return "", "", "", errors.New("userinfo returned invalid id field") + return "", "", "", "", "", errors.New("userinfo returned invalid id field") } email = strings.TrimSpace(email) @@ -418,8 +514,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s if username == "" { username = "linuxdo_" + subject } + displayName = strings.TrimSpace(displayName) + if displayName == "" { + displayName = username + } + avatarURL = strings.TrimSpace(avatarURL) - return email, username, subject, nil + return email, username, subject, displayName, avatarURL, nil } func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { @@ -436,10 +537,8 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod q.Set("scope", cfg.Scopes) } q.Set("state", state) - if cfg.UsePKCE { - q.Set("code_challenge", codeChallenge) - q.Set("code_challenge_method", "S256") - } + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") u.RawQuery = q.Encode() return u.String(), nil diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index ff169c52..90bc10d1 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -41,11 +41,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "alice", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "Alice", displayName) + require.Equal(t, "https://cdn.example/avatar.png", avatarURL) } func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { @@ -53,11 +55,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "linuxdo_123", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "linuxdo_123", displayName) + require.Equal(t, "", avatarURL) } func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { @@ -65,11 +69,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) require.Error(t, err) tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) - _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) require.Error(t, err) } diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go new file mode 100644 index 00000000..a758c0b9 --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -0,0 +1,263 @@ +package handler + +import ( + "net/http" + "net/url" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +const ( + oauthPendingBrowserCookiePath = "/api/v1/auth/oauth" + oauthPendingBrowserCookieName = "oauth_pending_browser_session" + oauthPendingSessionCookiePath = "/api/v1/auth/oauth/pending" + oauthPendingSessionCookieName = "oauth_pending_session" + oauthPendingCookieMaxAgeSec = 10 * 60 + + oauthCompletionResponseKey = "completion_response" +) + +type oauthPendingSessionPayload struct { + Intent string + Identity service.PendingAuthIdentityKey + ResolvedEmail string + RedirectTo string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + CompletionResponse map[string]any +} + +func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { + if h == nil || h.authService == nil || h.authService.EntClient() == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil +} + +func generateOAuthPendingBrowserSession() (string, error) { + return oauth.GenerateState() +} + +func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: encodeCookieValue(sessionKey), + Path: oauthPendingBrowserCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: "", + Path: oauthPendingBrowserCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingBrowserCookieName) +} + +func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: encodeCookieValue(sessionToken), + Path: oauthPendingSessionCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: "", + Path: oauthPendingSessionCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingSessionCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingSessionCookieName) +} + +func redirectToFrontendCallback(c *gin.Context, frontendCallback string) { + u, err := url.Parse(frontendCallback) + if err != nil { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = "" + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error { + svc, err := h.pendingIdentityService() + if err != nil { + return err + } + + session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{ + Intent: strings.TrimSpace(payload.Intent), + Identity: payload.Identity, + ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail), + RedirectTo: strings.TrimSpace(payload.RedirectTo), + BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey), + UpstreamIdentityClaims: payload.UpstreamIdentityClaims, + LocalFlowState: map[string]any{ + oauthCompletionResponseKey: payload.CompletionResponse, + }, + }) + if err != nil { + return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err) + } + + setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c)) + return nil +} + +func readCompletionResponse(session map[string]any) (map[string]any, bool) { + if len(session) == 0 { + return nil, false + } + value, ok := session[oauthCompletionResponseKey] + if !ok { + return nil, false + } + result, ok := value.(map[string]any) + if !ok { + return nil, false + } + return result, true +} + +func pendingSessionStringValue(values map[string]any, key string) string { + if len(values) == 0 { + return "" + } + raw, ok := values[key] + if !ok { + return "" + } + value, ok := raw.(string) + if !ok { + return "" + } + return strings.TrimSpace(value) +} + +func pendingSessionWantsInvitation(payload map[string]any) bool { + return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") +} + +func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { + if len(payload) == 0 || len(upstream) == 0 { + return + } + + displayName := pendingSessionStringValue(upstream, "suggested_display_name") + avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url") + + if displayName != "" { + if _, exists := payload["suggested_display_name"]; !exists { + payload["suggested_display_name"] = displayName + } + } + if avatarURL != "" { + if _, exists := payload["suggested_avatar_url"]; !exists { + payload["suggested_avatar_url"] = avatarURL + } + } + if displayName != "" || avatarURL != "" { + payload["adoption_required"] = true + } +} + +// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload. +// POST /api/v1/auth/oauth/pending/exchange +func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + payload, ok := readCompletionResponse(session.LocalFlowState) + if !ok { + clearCookies() + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid")) + return + } + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := payload["redirect"]; !exists { + payload["redirect"] = session.RedirectTo + } + } + applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) + + if pendingSessionWantsInvitation(payload) { + response.Success(c, payload) + return + } + + if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + response.Success(c, payload) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go new file mode 100644 index 00000000..5517bae2 --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -0,0 +1,40 @@ +package handler + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestApplySuggestedProfileToCompletionResponse(t *testing.T) { + payload := map[string]any{ + "access_token": "token", + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Alice", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} + +func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) { + payload := map[string]any{ + "suggested_display_name": "Existing", + "adoption_required": false, + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Existing", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 37ef6833..e3694c8f 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -87,6 +87,8 @@ type oidcUserInfoClaims struct { Username string Subject string EmailVerified *bool + DisplayName string + AvatarURL string } type oidcJWKSet struct { @@ -338,12 +340,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, }, CompletionResponse: map[string]any{ "error": "invitation_required", @@ -371,12 +375,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, }, CompletionResponse: map[string]any{ "access_token": tokenPair.AccessToken, @@ -643,9 +649,26 @@ func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoC if verified, ok := getGJSONBool(body, "email_verified"); ok { claims.EmailVerified = &verified } + claims.DisplayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "preferred_username"), + getGJSON(body, "username"), + ) + claims.AvatarURL = firstNonEmpty( + getGJSON(body, "picture"), + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) claims.Email = strings.TrimSpace(claims.Email) claims.Username = strings.TrimSpace(claims.Username) claims.Subject = strings.TrimSpace(claims.Subject) + claims.DisplayName = strings.TrimSpace(claims.DisplayName) + claims.AvatarURL = strings.TrimSpace(claims.AvatarURL) return claims } diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index a4cf776a..c389db51 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -91,6 +91,26 @@ func TestOIDCParseAndValidateIDToken(t *testing.T) { require.Error(t, err) } +func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) { + cfg := config.OIDCConnectConfig{} + + claims := oidcParseUserInfo(`{ + "sub":"subject-1", + "preferred_username":"alice", + "name":"Alice Example", + "picture":"https://cdn.example/avatar.png", + "email":"alice@example.com", + "email_verified":true + }`, cfg) + + require.Equal(t, "subject-1", claims.Subject) + require.Equal(t, "alice", claims.Username) + require.Equal(t, "Alice Example", claims.DisplayName) + require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL) + require.NotNil(t, claims.EmailVerified) + require.True(t, *claims.EmailVerified) +} + func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes()) e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()) diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 837c4f4c..d7abcd6a 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -186,6 +186,18 @@ export interface RefreshTokenResponse { token_type: string } +export interface PendingOAuthExchangeResponse { + access_token?: string + refresh_token?: string + expires_in?: number + token_type?: string + redirect?: string + error?: string + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +} + /** * Refresh the access token using the refresh token * @returns New token pair @@ -337,12 +349,10 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { const { data } = await apiClient.post<{ @@ -351,7 +361,6 @@ export async function completeLinuxDoOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/linuxdo/complete-registration', { - pending_oauth_token: pendingOAuthToken, invitation_code: invitationCode }) return data @@ -359,12 +368,10 @@ export async function completeLinuxDoOAuthRegistration( /** * Complete OIDC OAuth registration by supplying an invitation code - * @param pendingOAuthToken - Short-lived JWT from the OAuth callback * @param invitationCode - Invitation code entered by the user * @returns Token pair on success */ export async function completeOIDCOAuthRegistration( - pendingOAuthToken: string, invitationCode: string ): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { const { data } = await apiClient.post<{ @@ -373,12 +380,16 @@ export async function completeOIDCOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/oidc/complete-registration', { - pending_oauth_token: pendingOAuthToken, invitation_code: invitationCode }) return data } +export async function exchangePendingOAuthCompletion(): Promise { + const { data } = await apiClient.post('/auth/oauth/pending/exchange', {}) + return data +} + export const authAPI = { login, login2FA, @@ -402,6 +413,7 @@ export const authAPI = { resetPassword, refreshToken, revokeAllSessions, + exchangePendingOAuthCompletion, completeLinuxDoOAuthRegistration, completeOIDCOAuthRegistration } From e9de839d8791151314873bad11ac9583dabd2738 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 17:39:57 +0800 Subject: [PATCH 019/326] feat: rebuild auth identity foundation flow --- backend/ent/authidentity.go | 266 + backend/ent/authidentity/authidentity.go | 209 + backend/ent/authidentity/where.go | 600 +++ backend/ent/authidentity_create.go | 1036 ++++ backend/ent/authidentity_delete.go | 88 + backend/ent/authidentity_query.go | 797 +++ backend/ent/authidentity_update.go | 923 ++++ backend/ent/authidentitychannel.go | 228 + .../authidentitychannel.go | 153 + backend/ent/authidentitychannel/where.go | 559 ++ backend/ent/authidentitychannel_create.go | 932 ++++ backend/ent/authidentitychannel_delete.go | 88 + backend/ent/authidentitychannel_query.go | 643 +++ backend/ent/authidentitychannel_update.go | 581 ++ backend/ent/client.go | 876 ++- backend/ent/ent.go | 60 +- backend/ent/hook/hook.go | 48 + backend/ent/identityadoptiondecision.go | 223 + .../identityadoptiondecision.go | 159 + backend/ent/identityadoptiondecision/where.go | 342 ++ .../ent/identityadoptiondecision_create.go | 843 +++ .../ent/identityadoptiondecision_delete.go | 88 + backend/ent/identityadoptiondecision_query.go | 721 +++ .../ent/identityadoptiondecision_update.go | 532 ++ backend/ent/intercept/intercept.go | 120 + backend/ent/migrate/schema.go | 216 + backend/ent/mutation.go | 4687 ++++++++++++++++- backend/ent/pendingauthsession.go | 399 ++ .../pendingauthsession/pendingauthsession.go | 279 + backend/ent/pendingauthsession/where.go | 1262 +++++ backend/ent/pendingauthsession_create.go | 1815 +++++++ backend/ent/pendingauthsession_delete.go | 88 + backend/ent/pendingauthsession_query.go | 717 +++ backend/ent/pendingauthsession_update.go | 1178 +++++ backend/ent/predicate/predicate.go | 12 + backend/ent/runtime/runtime.go | 266 +- backend/ent/schema/auth_identity.go | 93 + backend/ent/schema/auth_identity_channel.go | 72 + .../ent/schema/auth_identity_schema_test.go | 124 + .../ent/schema/identity_adoption_decision.go | 70 + backend/ent/schema/pending_auth_session.go | 134 + backend/ent/schema/user.go | 13 + backend/ent/tx.go | 12 + backend/ent/user.go | 79 +- backend/ent/user/user.go | 88 + backend/ent/user/where.go | 226 + backend/ent/user_create.go | 290 + backend/ent/user_query.go | 155 +- backend/ent/user_update.go | 474 ++ backend/internal/config/config.go | 15 +- .../internal/handler/admin/setting_handler.go | 189 +- ...tting_handler_auth_source_defaults_test.go | 149 + .../internal/handler/auth_linuxdo_oauth.go | 21 +- .../handler/auth_linuxdo_oauth_test.go | 87 + .../handler/auth_oauth_pending_flow.go | 325 ++ .../handler/auth_oauth_pending_flow_test.go | 457 ++ backend/internal/handler/auth_oidc_oauth.go | 21 +- .../internal/handler/auth_oidc_oauth_test.go | 84 + backend/internal/handler/auth_wechat_oauth.go | 618 +++ .../handler/auth_wechat_oauth_test.go | 411 ++ backend/internal/handler/dto/settings.go | 1 + .../handler/payment_webhook_handler.go | 2 +- .../handler/payment_webhook_handler_test.go | 34 + backend/internal/handler/setting_handler.go | 1 + backend/internal/handler/user_handler.go | 30 +- backend/internal/handler/user_handler_test.go | 136 + backend/internal/repository/api_key_repo.go | 6 + .../auth_identity_migration_report.go | 148 + .../repository/user_profile_identity_repo.go | 544 ++ ...ser_profile_identity_repo_contract_test.go | 428 ++ backend/internal/repository/user_repo.go | 44 + .../user_repo_sort_integration_test.go | 83 + backend/internal/server/api_contract_test.go | 4 +- backend/internal/server/routes/auth.go | 14 + .../service/admin_service_apikey_test.go | 9 + .../service/admin_service_delete_test.go | 12 + .../service/auth_pending_identity_service.go | 326 ++ .../auth_pending_identity_service_test.go | 224 + backend/internal/service/auth_service.go | 164 +- .../auth_service_identity_sync_test.go | 153 + .../auth_service_pending_oauth_test.go | 146 - backend/internal/service/domain_constants.go | 26 + .../service/openai_account_scheduler.go | 77 +- .../service/openai_account_scheduler_test.go | 430 +- ...enai_account_scheduler_ws_snapshot_test.go | 7 +- .../service/payment_config_service.go | 96 +- .../service/payment_config_service_test.go | 186 + .../service/payment_resume_service.go | 248 + .../service/payment_resume_service_test.go | 240 + backend/internal/service/payment_service.go | 40 +- backend/internal/service/setting_service.go | 256 +- ...tting_service_auth_source_defaults_test.go | 136 + backend/internal/service/settings_view.go | 1 + backend/internal/service/user.go | 34 +- backend/internal/service/user_service.go | 153 +- backend/internal/service/user_service_test.go | 180 +- .../108_auth_identity_foundation_core.sql | 141 + .../109_auth_identity_compat_backfill.sql | 125 + ...nding_auth_and_provider_default_grants.sql | 60 + ...11_payment_routing_and_scheduler_flags.sql | 8 + .../api/__tests__/auth-oauth-adoption.spec.ts | 60 + .../settings.authSourceDefaults.spec.ts | 118 + frontend/src/api/admin/settings.ts | 127 + frontend/src/api/auth.ts | 41 +- .../components/auth/WechatOAuthSection.vue | 53 + .../auth/__tests__/WechatOAuthSection.spec.ts | 74 + .../components/payment/PaymentStatusPanel.vue | 29 +- .../payment/__tests__/paymentFlow.spec.ts | 163 + .../src/components/payment/paymentFlow.ts | 197 + .../src/router/__tests__/wechat-route.spec.ts | 55 + frontend/src/router/index.ts | 9 + frontend/src/stores/app.ts | 1 + frontend/src/types/index.ts | 1 + frontend/src/views/admin/SettingsView.vue | 497 +- .../src/views/auth/LinuxDoCallbackView.vue | 263 +- frontend/src/views/auth/LoginView.vue | 10 +- frontend/src/views/auth/OidcCallbackView.vue | 267 +- frontend/src/views/auth/RegisterView.vue | 10 +- .../src/views/auth/WechatCallbackView.vue | 361 ++ .../__tests__/LinuxDoCallbackView.spec.ts | 180 + .../auth/__tests__/OidcCallbackView.spec.ts | 191 + .../auth/__tests__/WechatCallbackView.spec.ts | 241 + frontend/src/views/user/PaymentView.vue | 229 +- 123 files changed, 33599 insertions(+), 772 deletions(-) create mode 100644 backend/ent/authidentity.go create mode 100644 backend/ent/authidentity/authidentity.go create mode 100644 backend/ent/authidentity/where.go create mode 100644 backend/ent/authidentity_create.go create mode 100644 backend/ent/authidentity_delete.go create mode 100644 backend/ent/authidentity_query.go create mode 100644 backend/ent/authidentity_update.go create mode 100644 backend/ent/authidentitychannel.go create mode 100644 backend/ent/authidentitychannel/authidentitychannel.go create mode 100644 backend/ent/authidentitychannel/where.go create mode 100644 backend/ent/authidentitychannel_create.go create mode 100644 backend/ent/authidentitychannel_delete.go create mode 100644 backend/ent/authidentitychannel_query.go create mode 100644 backend/ent/authidentitychannel_update.go create mode 100644 backend/ent/identityadoptiondecision.go create mode 100644 backend/ent/identityadoptiondecision/identityadoptiondecision.go create mode 100644 backend/ent/identityadoptiondecision/where.go create mode 100644 backend/ent/identityadoptiondecision_create.go create mode 100644 backend/ent/identityadoptiondecision_delete.go create mode 100644 backend/ent/identityadoptiondecision_query.go create mode 100644 backend/ent/identityadoptiondecision_update.go create mode 100644 backend/ent/pendingauthsession.go create mode 100644 backend/ent/pendingauthsession/pendingauthsession.go create mode 100644 backend/ent/pendingauthsession/where.go create mode 100644 backend/ent/pendingauthsession_create.go create mode 100644 backend/ent/pendingauthsession_delete.go create mode 100644 backend/ent/pendingauthsession_query.go create mode 100644 backend/ent/pendingauthsession_update.go create mode 100644 backend/ent/schema/auth_identity.go create mode 100644 backend/ent/schema/auth_identity_channel.go create mode 100644 backend/ent/schema/auth_identity_schema_test.go create mode 100644 backend/ent/schema/identity_adoption_decision.go create mode 100644 backend/ent/schema/pending_auth_session.go create mode 100644 backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go create mode 100644 backend/internal/handler/auth_wechat_oauth.go create mode 100644 backend/internal/handler/auth_wechat_oauth_test.go create mode 100644 backend/internal/handler/user_handler_test.go create mode 100644 backend/internal/repository/auth_identity_migration_report.go create mode 100644 backend/internal/repository/user_profile_identity_repo.go create mode 100644 backend/internal/repository/user_profile_identity_repo_contract_test.go create mode 100644 backend/internal/service/auth_pending_identity_service.go create mode 100644 backend/internal/service/auth_pending_identity_service_test.go create mode 100644 backend/internal/service/auth_service_identity_sync_test.go delete mode 100644 backend/internal/service/auth_service_pending_oauth_test.go create mode 100644 backend/internal/service/payment_resume_service.go create mode 100644 backend/internal/service/payment_resume_service_test.go create mode 100644 backend/internal/service/setting_service_auth_source_defaults_test.go create mode 100644 backend/migrations/108_auth_identity_foundation_core.sql create mode 100644 backend/migrations/109_auth_identity_compat_backfill.sql create mode 100644 backend/migrations/110_pending_auth_and_provider_default_grants.sql create mode 100644 backend/migrations/111_payment_routing_and_scheduler_flags.sql create mode 100644 frontend/src/api/__tests__/auth-oauth-adoption.spec.ts create mode 100644 frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts create mode 100644 frontend/src/components/auth/WechatOAuthSection.vue create mode 100644 frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts create mode 100644 frontend/src/components/payment/__tests__/paymentFlow.spec.ts create mode 100644 frontend/src/components/payment/paymentFlow.ts create mode 100644 frontend/src/router/__tests__/wechat-route.spec.ts create mode 100644 frontend/src/views/auth/WechatCallbackView.vue create mode 100644 frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts create mode 100644 frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts create mode 100644 frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go new file mode 100644 index 00000000..5ccfcf19 --- /dev/null +++ b/backend/ent/authidentity.go @@ -0,0 +1,266 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentity is the model entity for the AuthIdentity schema. +type AuthIdentity struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // VerifiedAt holds the value of the "verified_at" field. + VerifiedAt *time.Time `json:"verified_at,omitempty"` + // Issuer holds the value of the "issuer" field. + Issuer *string `json:"issuer,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityQuery when eager-loading is set. + Edges AuthIdentityEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // Channels holds the value of the channels edge. + Channels []*AuthIdentityChannel `json:"channels,omitempty"` + // AdoptionDecisions holds the value of the adoption_decisions edge. + AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [3]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AuthIdentityEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// ChannelsOrErr returns the Channels value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) { + if e.loadedTypes[1] { + return e.Channels, nil + } + return nil, &NotLoadedError{edge: "channels"} +} + +// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) { + if e.loadedTypes[2] { + return e.AdoptionDecisions, nil + } + return nil, &NotLoadedError{edge: "adoption_decisions"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthIdentity) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentity.FieldMetadata: + values[i] = new([]byte) + case authidentity.FieldID, authidentity.FieldUserID: + values[i] = new(sql.NullInt64) + case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer: + values[i] = new(sql.NullString) + case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentity fields. +func (_m *AuthIdentity) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentity.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentity.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentity.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentity.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case authidentity.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentity.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentity.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) + } else if value.Valid { + _m.ProviderSubject = value.String + } + case authidentity.FieldVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field verified_at", values[i]) + } else if value.Valid { + _m.VerifiedAt = new(time.Time) + *_m.VerifiedAt = value.Time + } + case authidentity.FieldIssuer: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field issuer", values[i]) + } else if value.Valid { + _m.Issuer = new(string) + *_m.Issuer = value.String + } + case authidentity.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentity) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryUser() *UserQuery { + return NewAuthIdentityClient(_m.config).QueryUser(_m) +} + +// QueryChannels queries the "channels" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery { + return NewAuthIdentityClient(_m.config).QueryChannels(_m) +} + +// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m) +} + +// Update returns a builder for updating this AuthIdentity. +// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne { + return NewAuthIdentityClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthIdentity) Unwrap() *AuthIdentity { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentity is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthIdentity) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentity(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.VerifiedAt; v != nil { + builder.WriteString("verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Issuer; v != nil { + builder.WriteString("issuer=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentities is a parsable slice of AuthIdentity. +type AuthIdentities []*AuthIdentity diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go new file mode 100644 index 00000000..c90be759 --- /dev/null +++ b/backend/ent/authidentity/authidentity.go @@ -0,0 +1,209 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentity type in the database. + Label = "auth_identity" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldVerifiedAt holds the string denoting the verified_at field in the database. + FieldVerifiedAt = "verified_at" + // FieldIssuer holds the string denoting the issuer field in the database. + FieldIssuer = "issuer" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // EdgeChannels holds the string denoting the channels edge name in mutations. + EdgeChannels = "channels" + // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations. + EdgeAdoptionDecisions = "adoption_decisions" + // Table holds the table name of the authidentity in the database. + Table = "auth_identities" + // UserTable is the table that holds the user relation/edge. + UserTable = "auth_identities" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" + // ChannelsTable is the table that holds the channels relation/edge. + ChannelsTable = "auth_identity_channels" + // ChannelsInverseTable is the table name for the AuthIdentityChannel entity. + // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package. + ChannelsInverseTable = "auth_identity_channels" + // ChannelsColumn is the table column denoting the channels relation/edge. + ChannelsColumn = "identity_id" + // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge. + AdoptionDecisionsTable = "identity_adoption_decisions" + // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. + AdoptionDecisionsInverseTable = "identity_adoption_decisions" + // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge. + AdoptionDecisionsColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentity fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldUserID, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldVerifiedAt, + FieldIssuer, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. + DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentity queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByVerifiedAt orders the results by the verified_at field. +func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc() +} + +// ByIssuer orders the results by the issuer field. +func ByIssuer(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIssuer, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByChannelsCount orders the results by channels count. +func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...) + } +} + +// ByChannels orders the results by channels terms. +func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByAdoptionDecisionsCount orders the results by adoption_decisions count. +func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...) + } +} + +// ByAdoptionDecisions orders the results by adoption_decisions terms. +func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newChannelsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ChannelsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) +} +func newAdoptionDecisionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) +} diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go new file mode 100644 index 00000000..3dbf3178 --- /dev/null +++ b/backend/ent/authidentity/where.go @@ -0,0 +1,600 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ. +func VerifiedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ. +func Issuer(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. +func ProviderSubjectLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// VerifiedAtEQ applies the EQ predicate on the "verified_at" field. +func VerifiedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field. +func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtIn applies the In predicate on the "verified_at" field. +func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field. +func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtGT applies the GT predicate on the "verified_at" field. +func VerifiedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v)) +} + +// VerifiedAtGTE applies the GTE predicate on the "verified_at" field. +func VerifiedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v)) +} + +// VerifiedAtLT applies the LT predicate on the "verified_at" field. +func VerifiedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v)) +} + +// VerifiedAtLTE applies the LTE predicate on the "verified_at" field. +func VerifiedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v)) +} + +// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field. +func VerifiedAtIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt)) +} + +// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field. +func VerifiedAtNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt)) +} + +// IssuerEQ applies the EQ predicate on the "issuer" field. +func IssuerEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// IssuerNEQ applies the NEQ predicate on the "issuer" field. +func IssuerNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v)) +} + +// IssuerIn applies the In predicate on the "issuer" field. +func IssuerIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...)) +} + +// IssuerNotIn applies the NotIn predicate on the "issuer" field. +func IssuerNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...)) +} + +// IssuerGT applies the GT predicate on the "issuer" field. +func IssuerGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v)) +} + +// IssuerGTE applies the GTE predicate on the "issuer" field. +func IssuerGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v)) +} + +// IssuerLT applies the LT predicate on the "issuer" field. +func IssuerLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v)) +} + +// IssuerLTE applies the LTE predicate on the "issuer" field. +func IssuerLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v)) +} + +// IssuerContains applies the Contains predicate on the "issuer" field. +func IssuerContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v)) +} + +// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field. +func IssuerHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v)) +} + +// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field. +func IssuerHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v)) +} + +// IssuerIsNil applies the IsNil predicate on the "issuer" field. +func IssuerIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer)) +} + +// IssuerNotNil applies the NotNil predicate on the "issuer" field. +func IssuerNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer)) +} + +// IssuerEqualFold applies the EqualFold predicate on the "issuer" field. +func IssuerEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v)) +} + +// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field. +func IssuerContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasChannels applies the HasEdge predicate on the "channels" edge. +func HasChannels() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasChannelsWith applies the HasEdge predicate on the "channels" edge with a given conditions (other predicates). +func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newChannelsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge. +func HasAdoptionDecisions() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with a given conditions (other predicates). +func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newAdoptionDecisionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go new file mode 100644 index 00000000..e287705c --- /dev/null +++ b/backend/ent/authidentity_create.go @@ -0,0 +1,1036 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityCreate is the builder for creating a AuthIdentity entity. +type AuthIdentityCreate struct { + config + mutation *AuthIdentityMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetUserID sets the "user_id" field. +func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetVerifiedAt sets the "verified_at" field. +func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetVerifiedAt(v) + return _c +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetVerifiedAt(*v) + } + return _c +} + +// SetIssuer sets the "issuer" field. +func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate { + _c.mutation.SetIssuer(v) + return _c +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate { + if v != nil { + _c.SetIssuer(*v) + } + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetUser sets the "user" edge to the User entity. +func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate { + return _c.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddChannelIDs(ids...) + return _c +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddAdoptionDecisionIDs(ids...) + return _c +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation { + return _c.mutation +} + +// Save creates the AuthIdentity in the database. +func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthIdentityCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentity.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentity.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentity.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthIdentityCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)} + } + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)} + } + if len(_c.mutation.UserIDs()) == 0 { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)} + } + return nil +} + +func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentity{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + _node.VerifiedAt = &value + } + if value, ok := _c.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + _node.Issuer = &value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne { + _c.conflict = opts + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityUpsertOne is the builder for "upsert"-ing + // one AuthIdentity node. + AuthIdentityUpsertOne struct { + create *AuthIdentityCreate + } + + // AuthIdentityUpsert is the "OnConflict" setter. + AuthIdentityUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUpdatedAt) + return u +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert { + u.Set(authidentity.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUserID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderSubject) + return u +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldVerifiedAt, v) + return u +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldVerifiedAt) + return u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldVerifiedAt) + return u +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldIssuer, v) + return u +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldIssuer) + return u +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldIssuer) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert { + u.Set(authidentity.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk. +type AuthIdentityCreateBulk struct { + config + err error + builders []*AuthIdentityCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentity entities in the database. +func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentity, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk { + _c.conflict = opts + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// AuthIdentityUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentity nodes. +type AuthIdentityUpsertBulk struct { + create *AuthIdentityCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go new file mode 100644 index 00000000..4f1f6f3c --- /dev/null +++ b/backend/ent/authidentity_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityDelete is the builder for deleting a AuthIdentity entity. +type AuthIdentityDelete struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity. +type AuthIdentityDeleteOne struct { + _d *AuthIdentityDelete +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentity.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go new file mode 100644 index 00000000..ff27ef3c --- /dev/null +++ b/backend/ent/authidentity_query.go @@ -0,0 +1,797 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityQuery is the builder for querying AuthIdentity entities. +type AuthIdentityQuery struct { + config + ctx *QueryContext + order []authidentity.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentity + withUser *UserQuery + withChannels *AuthIdentityChannelQuery + withAdoptionDecisions *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityQuery builder. +func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryUser chains the current query on the "user" edge. +func (_q *AuthIdentityQuery) QueryUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryChannels chains the current query on the "channels" edge. +func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge. +func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentity entity from the query. +// Returns a *NotFoundError when no AuthIdentity was found. +func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentity.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentity ID from the query. +// Returns a *NotFoundError when no AuthIdentity ID was found. +func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentity.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentity entity is found. +// Returns a *NotFoundError when no AuthIdentity entities are found. +func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentity.Label} + default: + return nil, &NotSingularError{authidentity.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentity ID in the query. +// Returns a *NotSingularError when more than one AuthIdentity ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentity.Label} + default: + err = &NotSingularError{authidentity.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentities. +func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]() + return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentity IDs. +func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery { + if _q == nil { + return nil + } + return &AuthIdentityQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentity.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentity{}, _q.predicates...), + withUser: _q.withUser.Clone(), + withChannels: _q.withChannels.Clone(), + withAdoptionDecisions: _q.withAdoptionDecisions.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUser = query + return _q +} + +// WithChannels tells the query-builder to eager-load the nodes that are connected to +// the "channels" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withChannels = query + return _q +} + +// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecisions = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// GroupBy(authidentity.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentity.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// Select(authidentity.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q} + sbuild.label = authidentity.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentitySelect configured with the given aggregations. +func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentity.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) { + var ( + nodes = []*AuthIdentity{} + _spec = _q.querySpec() + loadedTypes = [3]bool{ + _q.withUser != nil, + _q.withChannels != nil, + _q.withAdoptionDecisions != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentity).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentity{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withUser; query != nil { + if err := _q.loadUser(ctx, query, nodes, nil, + func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := _q.withChannels; query != nil { + if err := _q.loadChannels(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} }, + func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecisions; query != nil { + if err := _q.loadAdoptionDecisions(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} }, + func(n *AuthIdentity, e *IdentityAdoptionDecision) { + n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e) + }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentity) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID) + } + query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + if fk == nil { + return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for i := range fields { + if fields[i] != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withUser != nil { + _spec.Node.AddColumnOnce(authidentity.FieldUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentity.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentity.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities. +type AuthIdentityGroupBy struct { + selector + build *AuthIdentityQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities. +type AuthIdentitySelect struct { + *AuthIdentityQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go new file mode 100644 index 00000000..c457470b --- /dev/null +++ b/backend/ent/authidentity_update.go @@ -0,0 +1,923 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityUpdate is the builder for updating AuthIdentity entities. +type AuthIdentityUpdate struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. +func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. +func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity. +type AuthIdentityUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. +func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. +func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentity entity. +func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for _, f := range fields { + if !authidentity.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentity{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go new file mode 100644 index 00000000..1ff3e5d1 --- /dev/null +++ b/backend/ent/authidentitychannel.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema. +type AuthIdentityChannel struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID int64 `json:"identity_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // Channel holds the value of the "channel" field. + Channel string `json:"channel,omitempty"` + // ChannelAppID holds the value of the "channel_app_id" field. + ChannelAppID string `json:"channel_app_id,omitempty"` + // ChannelSubject holds the value of the "channel_subject" field. + ChannelSubject string `json:"channel_subject,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set. + Edges AuthIdentityChannelEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityChannelEdges struct { + // Identity holds the value of the identity edge. + Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldMetadata: + values[i] = new([]byte) + case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID: + values[i] = new(sql.NullInt64) + case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject: + values[i] = new(sql.NullString) + case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentityChannel fields. +func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentitychannel.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentitychannel.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentitychannel.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = value.Int64 + } + case authidentitychannel.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentitychannel.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentitychannel.FieldChannel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel", values[i]) + } else if value.Valid { + _m.Channel = value.String + } + case authidentitychannel.FieldChannelAppID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_app_id", values[i]) + } else if value.Valid { + _m.ChannelAppID = value.String + } + case authidentitychannel.FieldChannelSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_subject", values[i]) + } else if value.Valid { + _m.ChannelSubject = value.String + } + case authidentitychannel.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity. +func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery { + return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this AuthIdentityChannel. +// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne { + return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentityChannel is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthIdentityChannel) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentityChannel(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", _m.IdentityID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("channel=") + builder.WriteString(_m.Channel) + builder.WriteString(", ") + builder.WriteString("channel_app_id=") + builder.WriteString(_m.ChannelAppID) + builder.WriteString(", ") + builder.WriteString("channel_subject=") + builder.WriteString(_m.ChannelSubject) + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentityChannels is a parsable slice of AuthIdentityChannel. +type AuthIdentityChannels []*AuthIdentityChannel diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go new file mode 100644 index 00000000..7dcc98bb --- /dev/null +++ b/backend/ent/authidentitychannel/authidentitychannel.go @@ -0,0 +1,153 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentitychannel type in the database. + Label = "auth_identity_channel" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldChannel holds the string denoting the channel field in the database. + FieldChannel = "channel" + // FieldChannelAppID holds the string denoting the channel_app_id field in the database. + FieldChannelAppID = "channel_app_id" + // FieldChannelSubject holds the string denoting the channel_subject field in the database. + FieldChannelSubject = "channel_subject" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the authidentitychannel in the database. + Table = "auth_identity_channels" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "auth_identity_channels" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentitychannel fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldIdentityID, + FieldProviderType, + FieldProviderKey, + FieldChannel, + FieldChannelAppID, + FieldChannelSubject, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + ChannelValidator func(string) error + // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + ChannelAppIDValidator func(string) error + // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + ChannelSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. + DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentityChannel queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByChannel orders the results by the channel field. +func ByChannel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannel, opts...).ToFunc() +} + +// ByChannelAppID orders the results by the channel_app_id field. +func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelAppID, opts...).ToFunc() +} + +// ByChannelSubject orders the results by the channel_subject field. +func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelSubject, opts...).ToFunc() +} + +// ByIdentityField orders the results by identity field. +func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go new file mode 100644 index 00000000..827dc384 --- /dev/null +++ b/backend/ent/authidentitychannel/where.go @@ -0,0 +1,559 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ. +func Channel(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ. +func ChannelAppID(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ. +func ChannelSubject(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. +func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ChannelEQ applies the EQ predicate on the "channel" field. +func ChannelEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelNEQ applies the NEQ predicate on the "channel" field. +func ChannelNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v)) +} + +// ChannelIn applies the In predicate on the "channel" field. +func ChannelIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...)) +} + +// ChannelNotIn applies the NotIn predicate on the "channel" field. +func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...)) +} + +// ChannelGT applies the GT predicate on the "channel" field. +func ChannelGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v)) +} + +// ChannelGTE applies the GTE predicate on the "channel" field. +func ChannelGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v)) +} + +// ChannelLT applies the LT predicate on the "channel" field. +func ChannelLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v)) +} + +// ChannelLTE applies the LTE predicate on the "channel" field. +func ChannelLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v)) +} + +// ChannelContains applies the Contains predicate on the "channel" field. +func ChannelContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v)) +} + +// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field. +func ChannelHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v)) +} + +// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field. +func ChannelHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v)) +} + +// ChannelEqualFold applies the EqualFold predicate on the "channel" field. +func ChannelEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v)) +} + +// ChannelContainsFold applies the ContainsFold predicate on the "channel" field. +func ChannelContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v)) +} + +// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field. +func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field. +func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDIn applies the In predicate on the "channel_app_id" field. +func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field. +func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field. +func ChannelAppIDGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v)) +} + +// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field. +func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v)) +} + +// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field. +func ChannelAppIDLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v)) +} + +// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field. +func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v)) +} + +// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field. +func ChannelAppIDContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v)) +} + +// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field. +func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v)) +} + +// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field. +func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v)) +} + +// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field. +func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v)) +} + +// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field. +func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v)) +} + +// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field. +func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field. +func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectIn applies the In predicate on the "channel_subject" field. +func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field. +func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectGT applies the GT predicate on the "channel_subject" field. +func ChannelSubjectGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v)) +} + +// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field. +func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v)) +} + +// ChannelSubjectLT applies the LT predicate on the "channel_subject" field. +func ChannelSubjectLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v)) +} + +// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field. +func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v)) +} + +// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field. +func ChannelSubjectContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v)) +} + +// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field. +func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v)) +} + +// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field. +func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v)) +} + +// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field. +func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v)) +} + +// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field. +func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v)) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go new file mode 100644 index 00000000..4ce28479 --- /dev/null +++ b/backend/ent/authidentitychannel_create.go @@ -0,0 +1,932 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity. +type AuthIdentityChannelCreate struct { + config + mutation *AuthIdentityChannelMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetIdentityID sets the "identity_id" field. +func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetChannel sets the "channel" field. +func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannel(v) + return _c +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelAppID(v) + return _c +} + +// SetChannelSubject sets the "channel_subject" field. +func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelSubject(v) + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation { + return _c.mutation +} + +// Save creates the AuthIdentityChannel in the database. +func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthIdentityChannelCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentitychannel.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentitychannel.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentitychannel.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthIdentityChannelCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)} + } + if _, ok := _c.mutation.IdentityID(); !ok { + return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.Channel(); !ok { + return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)} + } + if v, ok := _c.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelAppID(); !ok { + return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)} + } + if v, ok := _c.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelSubject(); !ok { + return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)} + } + if v, ok := _c.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)} + } + if len(_c.mutation.IdentityIDs()) == 0 { + return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)} + } + return nil +} + +func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentityChannel{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + _node.Channel = value + } + if value, ok := _c.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + _node.ChannelAppID = value + } + if value, ok := _c.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + _node.ChannelSubject = value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentityChannel.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne { + _c.conflict = opts + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing + // one AuthIdentityChannel node. + AuthIdentityChannelUpsertOne struct { + create *AuthIdentityChannelCreate + } + + // AuthIdentityChannelUpsert is the "OnConflict" setter. + AuthIdentityChannelUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldUpdatedAt) + return u +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldIdentityID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderKey) + return u +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannel, v) + return u +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannel) + return u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelAppID, v) + return u +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelAppID) + return u +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelSubject, v) + return u +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelSubject) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk. +type AuthIdentityChannelCreateBulk struct { + config + err error + builders []*AuthIdentityChannelCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentityChannel entities in the database. +func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentityChannel, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityChannelMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentityChannel.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk { + _c.conflict = opts + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentityChannel nodes. +type AuthIdentityChannelUpsertBulk struct { + create *AuthIdentityChannelCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go new file mode 100644 index 00000000..1a4acac5 --- /dev/null +++ b/backend/ent/authidentitychannel_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity. +type AuthIdentityChannelDelete struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity. +type AuthIdentityChannelDeleteOne struct { + _d *AuthIdentityChannelDelete +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentitychannel.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go new file mode 100644 index 00000000..7a202b7f --- /dev/null +++ b/backend/ent/authidentitychannel_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities. +type AuthIdentityChannelQuery struct { + config + ctx *QueryContext + order []authidentitychannel.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentityChannel + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityChannelQuery builder. +func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryIdentity chains the current query on the "identity" edge. +func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentityChannel entity from the query. +// Returns a *NotFoundError when no AuthIdentityChannel was found. +func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentitychannel.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentityChannel ID from the query. +// Returns a *NotFoundError when no AuthIdentityChannel ID was found. +func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentitychannel.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found. +// Returns a *NotFoundError when no AuthIdentityChannel entities are found. +func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentitychannel.Label} + default: + return nil, &NotSingularError{authidentitychannel.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query. +// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentitychannel.Label} + default: + err = &NotSingularError{authidentitychannel.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentityChannels. +func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]() + return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentityChannel IDs. +func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery { + if _q == nil { + return nil + } + return &AuthIdentityChannelQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentitychannel.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// GroupBy(authidentitychannel.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityChannelGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentitychannel.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// Select(authidentitychannel.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q} + sbuild.label = authidentitychannel.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations. +func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentitychannel.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) { + var ( + nodes = []*AuthIdentityChannel{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentityChannel).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentityChannel{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentityChannel) + for i := range nodes { + fk := nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for i := range fields { + if fields[i] != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentitychannel.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentitychannel.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities. +type AuthIdentityChannelGroupBy struct { + selector + build *AuthIdentityChannelQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities. +type AuthIdentityChannelSelect struct { + *AuthIdentityChannelQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go new file mode 100644 index 00000000..b550c454 --- /dev/null +++ b/backend/ent/authidentitychannel_update.go @@ -0,0 +1,581 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities. +type AuthIdentityChannelUpdate struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityChannelUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity. +type AuthIdentityChannelUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentityChannel entity. +func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityChannelUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for _, f := range fields { + if !authidentitychannel.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentityChannel{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/client.go b/backend/ent/client.go index e52e015a..b02f519b 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -20,12 +20,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -60,18 +64,26 @@ type Client struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. + AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -118,12 +130,16 @@ func (c *Client) init() { c.AccountGroup = NewAccountGroupClient(c.config) c.Announcement = NewAnnouncementClient(c.config) c.AnnouncementRead = NewAnnouncementReadClient(c.config) + c.AuthIdentity = NewAuthIdentityClient(c.config) + c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config) c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) c.IdempotencyRecord = NewIdempotencyRecordClient(c.config) + c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config) c.PaymentAuditLog = NewPaymentAuditLogClient(c.config) c.PaymentOrder = NewPaymentOrderClient(c.config) c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config) + c.PendingAuthSession = NewPendingAuthSessionClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) @@ -229,34 +245,38 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { cfg := c.config cfg.driver = tx return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), + AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -274,34 +294,38 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) cfg := c.config cfg.driver = &txDriver{tx: tx, drv: c.driver} return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), + AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -332,11 +356,12 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) @@ -348,11 +373,12 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) @@ -372,18 +398,26 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Announcement.mutate(ctx, m) case *AnnouncementReadMutation: return c.AnnouncementRead.mutate(ctx, m) + case *AuthIdentityMutation: + return c.AuthIdentity.mutate(ctx, m) + case *AuthIdentityChannelMutation: + return c.AuthIdentityChannel.mutate(ctx, m) case *ErrorPassthroughRuleMutation: return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) case *IdempotencyRecordMutation: return c.IdempotencyRecord.mutate(ctx, m) + case *IdentityAdoptionDecisionMutation: + return c.IdentityAdoptionDecision.mutate(ctx, m) case *PaymentAuditLogMutation: return c.PaymentAuditLog.mutate(ctx, m) case *PaymentOrderMutation: return c.PaymentOrder.mutate(ctx, m) case *PaymentProviderInstanceMutation: return c.PaymentProviderInstance.mutate(ctx, m) + case *PendingAuthSessionMutation: + return c.PendingAuthSession.mutate(ctx, m) case *PromoCodeMutation: return c.PromoCode.mutate(ctx, m) case *PromoCodeUsageMutation: @@ -1231,6 +1265,336 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead } } +// AuthIdentityClient is a client for the AuthIdentity schema. +type AuthIdentityClient struct { + config +} + +// NewAuthIdentityClient returns a client for the AuthIdentity from the given config. +func NewAuthIdentityClient(c config) *AuthIdentityClient { + return &AuthIdentityClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`. +func (c *AuthIdentityClient) Use(hooks ...Hook) { + c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`. +func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...) +} + +// Create returns a builder for creating a AuthIdentity entity. +func (c *AuthIdentityClient) Create() *AuthIdentityCreate { + mutation := newAuthIdentityMutation(c.config, OpCreate) + return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentity entities. +func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk { + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentity. +func (c *AuthIdentityClient) Update() *AuthIdentityUpdate { + mutation := newAuthIdentityMutation(c.config, OpUpdate) + return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentity. +func (c *AuthIdentityClient) Delete() *AuthIdentityDelete { + mutation := newAuthIdentityMutation(c.config, OpDelete) + return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne { + builder := c.Delete().Where(authidentity.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentity. +func (c *AuthIdentityClient) Query() *AuthIdentityQuery { + return &AuthIdentityQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentity}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentity entity by its id. +func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) { + return c.Query().Where(authidentity.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryChannels queries the channels edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityClient) Hooks() []Hook { + return c.hooks.AuthIdentity +} + +// Interceptors returns the client interceptors. +func (c *AuthIdentityClient) Interceptors() []Interceptor { + return c.inters.AuthIdentity +} + +func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op()) + } +} + +// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema. +type AuthIdentityChannelClient struct { + config +} + +// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config. +func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient { + return &AuthIdentityChannelClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`. +func (c *AuthIdentityChannelClient) Use(hooks ...Hook) { + c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`. +func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...) +} + +// Create returns a builder for creating a AuthIdentityChannel entity. +func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate { + mutation := newAuthIdentityChannelMutation(c.config, OpCreate) + return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities. +func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk { + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityChannelCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdate) + return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete { + mutation := newAuthIdentityChannelMutation(c.config, OpDelete) + return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne { + builder := c.Delete().Where(authidentitychannel.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityChannelDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery { + return &AuthIdentityChannelQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentityChannel}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentityChannel entity by its id. +func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) { + return c.Query().Where(authidentitychannel.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryIdentity queries the identity edge of a AuthIdentityChannel. +func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityChannelClient) Hooks() []Hook { + return c.hooks.AuthIdentityChannel +} + +// Interceptors returns the client interceptors. +func (c *AuthIdentityChannelClient) Interceptors() []Interceptor { + return c.inters.AuthIdentityChannel +} + +func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op()) + } +} + // ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema. type ErrorPassthroughRuleClient struct { config @@ -1760,6 +2124,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco } } +// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecisionClient struct { + config +} + +// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config. +func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient { + return &IdentityAdoptionDecisionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) { + c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) { + c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...) +} + +// Create returns a builder for creating a IdentityAdoptionDecision entity. +func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate) + return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities. +func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk { + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*IdentityAdoptionDecisionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate) + return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete) + return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne { + builder := c.Delete().Where(identityadoptiondecision.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &IdentityAdoptionDecisionDeleteOne{builder} +} + +// Query returns a query builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery { + return &IdentityAdoptionDecisionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeIdentityAdoptionDecision}, + inters: c.Interceptors(), + } +} + +// Get returns a IdentityAdoptionDecision entity by its id. +func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) { + return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryIdentity queries the identity edge of a IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *IdentityAdoptionDecisionClient) Hooks() []Hook { + return c.hooks.IdentityAdoptionDecision +} + +// Interceptors returns the client interceptors. +func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor { + return c.inters.IdentityAdoptionDecision +} + +func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op()) + } +} + // PaymentAuditLogClient is a client for the PaymentAuditLog schema. type PaymentAuditLogClient struct { config @@ -2175,6 +2704,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr } } +// PendingAuthSessionClient is a client for the PendingAuthSession schema. +type PendingAuthSessionClient struct { + config +} + +// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config. +func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient { + return &PendingAuthSessionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`. +func (c *PendingAuthSessionClient) Use(hooks ...Hook) { + c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`. +func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) { + c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...) +} + +// Create returns a builder for creating a PendingAuthSession entity. +func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate { + mutation := newPendingAuthSessionMutation(c.config, OpCreate) + return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities. +func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk { + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*PendingAuthSessionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate { + mutation := newPendingAuthSessionMutation(c.config, OpUpdate) + return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete { + mutation := newPendingAuthSessionMutation(c.config, OpDelete) + return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne { + builder := c.Delete().Where(pendingauthsession.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &PendingAuthSessionDeleteOne{builder} +} + +// Query returns a query builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery { + return &PendingAuthSessionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypePendingAuthSession}, + inters: c.Interceptors(), + } +} + +// Get returns a PendingAuthSession entity by its id. +func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) { + return c.Query().Where(pendingauthsession.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryTargetUser queries the target_user edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *PendingAuthSessionClient) Hooks() []Hook { + return c.hooks.PendingAuthSession +} + +// Interceptors returns the client interceptors. +func (c *PendingAuthSessionClient) Interceptors() []Interceptor { + return c.inters.PendingAuthSession +} + +func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op()) + } +} + // PromoCodeClient is a client for the PromoCode schema. type PromoCodeClient struct { config @@ -3951,6 +4645,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery { return query } +// QueryAuthIdentities queries the auth_identities edge of a User. +func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User. +func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryUserAllowedGroups queries the user_allowed_groups edge of a User. func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: c.config}).Query() @@ -4628,18 +5354,20 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. type ( hooks struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord, + IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, + RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } inters struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord, + IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, + RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 96ed5e03..339e5369 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -17,12 +17,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -98,32 +102,36 @@ var ( func checkColumn(t, c string) error { initCheck.Do(func() { columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ - apikey.Table: apikey.ValidColumn, - account.Table: account.ValidColumn, - accountgroup.Table: accountgroup.ValidColumn, - announcement.Table: announcement.ValidColumn, - announcementread.Table: announcementread.ValidColumn, - errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, - group.Table: group.ValidColumn, - idempotencyrecord.Table: idempotencyrecord.ValidColumn, - paymentauditlog.Table: paymentauditlog.ValidColumn, - paymentorder.Table: paymentorder.ValidColumn, - paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, - promocode.Table: promocode.ValidColumn, - promocodeusage.Table: promocodeusage.ValidColumn, - proxy.Table: proxy.ValidColumn, - redeemcode.Table: redeemcode.ValidColumn, - securitysecret.Table: securitysecret.ValidColumn, - setting.Table: setting.ValidColumn, - subscriptionplan.Table: subscriptionplan.ValidColumn, - tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, - usagecleanuptask.Table: usagecleanuptask.ValidColumn, - usagelog.Table: usagelog.ValidColumn, - user.Table: user.ValidColumn, - userallowedgroup.Table: userallowedgroup.ValidColumn, - userattributedefinition.Table: userattributedefinition.ValidColumn, - userattributevalue.Table: userattributevalue.ValidColumn, - usersubscription.Table: usersubscription.ValidColumn, + apikey.Table: apikey.ValidColumn, + account.Table: account.ValidColumn, + accountgroup.Table: accountgroup.ValidColumn, + announcement.Table: announcement.ValidColumn, + announcementread.Table: announcementread.ValidColumn, + authidentity.Table: authidentity.ValidColumn, + authidentitychannel.Table: authidentitychannel.ValidColumn, + errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, + group.Table: group.ValidColumn, + idempotencyrecord.Table: idempotencyrecord.ValidColumn, + identityadoptiondecision.Table: identityadoptiondecision.ValidColumn, + paymentauditlog.Table: paymentauditlog.ValidColumn, + paymentorder.Table: paymentorder.ValidColumn, + paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, + pendingauthsession.Table: pendingauthsession.ValidColumn, + promocode.Table: promocode.ValidColumn, + promocodeusage.Table: promocodeusage.ValidColumn, + proxy.Table: proxy.ValidColumn, + redeemcode.Table: redeemcode.ValidColumn, + securitysecret.Table: securitysecret.ValidColumn, + setting.Table: setting.ValidColumn, + subscriptionplan.Table: subscriptionplan.ValidColumn, + tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, + usagecleanuptask.Table: usagecleanuptask.ValidColumn, + usagelog.Table: usagelog.ValidColumn, + user.Table: user.ValidColumn, + userallowedgroup.Table: userallowedgroup.ValidColumn, + userattributedefinition.Table: userattributedefinition.ValidColumn, + userattributevalue.Table: userattributevalue.ValidColumn, + usersubscription.Table: usersubscription.ValidColumn, }) }) return columnCheck(t, c) diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 199dacea..46ac02bc 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -69,6 +69,30 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary +// function as AuthIdentity mutator. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary +// function as AuthIdentityChannel mutator. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary // function as ErrorPassthroughRule mutator. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error) @@ -105,6 +129,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent. return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary +// function as IdentityAdoptionDecision mutator. +type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary // function as PaymentAuditLog mutator. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error) @@ -141,6 +177,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation) return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary +// function as PendingAuthSession mutator. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.PendingAuthSessionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary // function as PromoCode mutator. type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error) diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go new file mode 100644 index 00000000..ecaee65c --- /dev/null +++ b/backend/ent/identityadoptiondecision.go @@ -0,0 +1,223 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecision struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // PendingAuthSessionID holds the value of the "pending_auth_session_id" field. + PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID *int64 `json:"identity_id,omitempty"` + // AdoptDisplayName holds the value of the "adopt_display_name" field. + AdoptDisplayName bool `json:"adopt_display_name,omitempty"` + // AdoptAvatar holds the value of the "adopt_avatar" field. + AdoptAvatar bool `json:"adopt_avatar,omitempty"` + // DecidedAt holds the value of the "decided_at" field. + DecidedAt time.Time `json:"decided_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set. + Edges IdentityAdoptionDecisionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph. +type IdentityAdoptionDecisionEdges struct { + // PendingAuthSession holds the value of the pending_auth_session edge. + PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"` + // Identity holds the value of the identity edge. + Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) { + if e.PendingAuthSession != nil { + return e.PendingAuthSession, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: pendingauthsession.Label} + } + return nil, &NotLoadedError{edge: "pending_auth_session"} +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar: + values[i] = new(sql.NullBool) + case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID: + values[i] = new(sql.NullInt64) + case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the IdentityAdoptionDecision fields. +func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case identityadoptiondecision.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case identityadoptiondecision.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case identityadoptiondecision.FieldPendingAuthSessionID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i]) + } else if value.Valid { + _m.PendingAuthSessionID = value.Int64 + } + case identityadoptiondecision.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = new(int64) + *_m.IdentityID = value.Int64 + } + case identityadoptiondecision.FieldAdoptDisplayName: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i]) + } else if value.Valid { + _m.AdoptDisplayName = value.Bool + } + case identityadoptiondecision.FieldAdoptAvatar: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i]) + } else if value.Valid { + _m.AdoptAvatar = value.Bool + } + case identityadoptiondecision.FieldDecidedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field decided_at", values[i]) + } else if value.Valid { + _m.DecidedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision. +// This includes values selected through modifiers, order, etc. +func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m) +} + +// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this IdentityAdoptionDecision. +// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne { + return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: IdentityAdoptionDecision is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *IdentityAdoptionDecision) String() string { + var builder strings.Builder + builder.WriteString("IdentityAdoptionDecision(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("pending_auth_session_id=") + builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID)) + builder.WriteString(", ") + if v := _m.IdentityID; v != nil { + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("adopt_display_name=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName)) + builder.WriteString(", ") + builder.WriteString("adopt_avatar=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar)) + builder.WriteString(", ") + builder.WriteString("decided_at=") + builder.WriteString(_m.DecidedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision. +type IdentityAdoptionDecisions []*IdentityAdoptionDecision diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go new file mode 100644 index 00000000..93adaf73 --- /dev/null +++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go @@ -0,0 +1,159 @@ +// Code generated by ent, DO NOT EDIT. + +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the identityadoptiondecision type in the database. + Label = "identity_adoption_decision" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database. + FieldPendingAuthSessionID = "pending_auth_session_id" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database. + FieldAdoptDisplayName = "adopt_display_name" + // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database. + FieldAdoptAvatar = "adopt_avatar" + // FieldDecidedAt holds the string denoting the decided_at field in the database. + FieldDecidedAt = "decided_at" + // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations. + EdgePendingAuthSession = "pending_auth_session" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the identityadoptiondecision in the database. + Table = "identity_adoption_decisions" + // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge. + PendingAuthSessionTable = "identity_adoption_decisions" + // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. + PendingAuthSessionInverseTable = "pending_auth_sessions" + // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge. + PendingAuthSessionColumn = "pending_auth_session_id" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "identity_adoption_decisions" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for identityadoptiondecision fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldPendingAuthSessionID, + FieldIdentityID, + FieldAdoptDisplayName, + FieldAdoptAvatar, + FieldDecidedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field. + DefaultAdoptDisplayName bool + // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field. + DefaultAdoptAvatar bool + // DefaultDecidedAt holds the default value on creation for the "decided_at" field. + DefaultDecidedAt func() time.Time +) + +// OrderOption defines the ordering options for the IdentityAdoptionDecision queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionID orders the results by the pending_auth_session_id field. +func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByAdoptDisplayName orders the results by the adopt_display_name field. +func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc() +} + +// ByAdoptAvatar orders the results by the adopt_avatar field. +func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc() +} + +// ByDecidedAt orders the results by the decided_at field. +func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDecidedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionField orders the results by pending_auth_session field. +func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...)) + } +} + +// ByIdentityField orders the results by identity field. +func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newPendingAuthSessionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go new file mode 100644 index 00000000..1968f175 --- /dev/null +++ b/backend/ent/identityadoptiondecision/where.go @@ -0,0 +1,342 @@ +// Code generated by ent, DO NOT EDIT. + +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ. +func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ. +func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ. +func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ. +func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...)) +} + +// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. +func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field. +func IdentityIDIsNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID)) +} + +// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field. +func IdentityIDNotNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID)) +} + +// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field. +func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field. +func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v)) +} + +// DecidedAtEQ applies the EQ predicate on the "decided_at" field. +func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field. +func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v)) +} + +// DecidedAtIn applies the In predicate on the "decided_at" field. +func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...)) +} + +// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field. +func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...)) +} + +// DecidedAtGT applies the GT predicate on the "decided_at" field. +func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v)) +} + +// DecidedAtGTE applies the GTE predicate on the "decided_at" field. +func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v)) +} + +// DecidedAtLT applies the LT predicate on the "decided_at" field. +func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v)) +} + +// DecidedAtLTE applies the LTE predicate on the "decided_at" field. +func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v)) +} + +// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge. +func HasPendingAuthSession() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates). +func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newPendingAuthSessionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.NotPredicates(p)) +} diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go new file mode 100644 index 00000000..491ba9f9 --- /dev/null +++ b/backend/ent/identityadoptiondecision_create.go @@ -0,0 +1,843 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionCreate struct { + config + mutation *IdentityAdoptionDecisionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetPendingAuthSessionID(v) + return _c +} + +// SetIdentityID sets the "identity_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetIdentityID(*v) + } + return _c +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptDisplayName(v) + return _c +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptDisplayName(*v) + } + return _c +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptAvatar(v) + return _c +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptAvatar(*v) + } + return _c +} + +// SetDecidedAt sets the "decided_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetDecidedAt(v) + return _c +} + +// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetDecidedAt(*v) + } + return _c +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate { + return _c.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation { + return _c.mutation +} + +// Save creates the IdentityAdoptionDecision in the database. +func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *IdentityAdoptionDecisionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := identityadoptiondecision.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + v := identityadoptiondecision.DefaultAdoptDisplayName + _c.mutation.SetAdoptDisplayName(v) + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + v := identityadoptiondecision.DefaultAdoptAvatar + _c.mutation.SetAdoptAvatar(v) + } + if _, ok := _c.mutation.DecidedAt(); !ok { + v := identityadoptiondecision.DefaultDecidedAt() + _c.mutation.SetDecidedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *IdentityAdoptionDecisionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)} + } + if _, ok := _c.mutation.PendingAuthSessionID(); !ok { + return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)} + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)} + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)} + } + if _, ok := _c.mutation.DecidedAt(); !ok { + return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)} + } + if len(_c.mutation.PendingAuthSessionIDs()) == 0 { + return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)} + } + return nil +} + +func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) { + var ( + _node = &IdentityAdoptionDecision{config: _c.config} + _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + _node.AdoptDisplayName = value + } + if value, ok := _c.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + _node.AdoptAvatar = value + } + if value, ok := _c.mutation.DecidedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value) + _node.DecidedAt = value + } + if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.PendingAuthSessionID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +type ( + // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing + // one IdentityAdoptionDecision node. + IdentityAdoptionDecisionUpsertOne struct { + create *IdentityAdoptionDecisionCreate + } + + // IdentityAdoptionDecisionUpsert is the "OnConflict" setter. + IdentityAdoptionDecisionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldUpdatedAt) + return u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v) + return u +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID) + return u +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldIdentityID) + return u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetNull(identityadoptiondecision.FieldIdentityID) + return u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptDisplayName, v) + return u +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName) + return u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptAvatar, v) + return u +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := u.create.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk. +type IdentityAdoptionDecisionCreateBulk struct { + config + err error + builders []*IdentityAdoptionDecisionCreate + conflict []sql.ConflictOption +} + +// Save creates the IdentityAdoptionDecision entities in the database. +func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*IdentityAdoptionDecision, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*IdentityAdoptionDecisionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing +// a bulk of IdentityAdoptionDecision nodes. +type IdentityAdoptionDecisionUpsertBulk struct { + create *IdentityAdoptionDecisionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := b.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go new file mode 100644 index 00000000..ef3d328d --- /dev/null +++ b/backend/ent/identityadoptiondecision_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDelete struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDeleteOne struct { + _d *IdentityAdoptionDecisionDelete +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{identityadoptiondecision.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go new file mode 100644 index 00000000..4082d8ee --- /dev/null +++ b/backend/ent/identityadoptiondecision_query.go @@ -0,0 +1,721 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionQuery struct { + config + ctx *QueryContext + order []identityadoptiondecision.OrderOption + inters []Interceptor + predicates []predicate.IdentityAdoptionDecision + withPendingAuthSession *PendingAuthSessionQuery + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder. +func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge. +func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryIdentity chains the current query on the "identity" edge. +func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first IdentityAdoptionDecision entity from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision was found. +func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{identityadoptiondecision.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first IdentityAdoptionDecision ID from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found. +func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{identityadoptiondecision.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found. +// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found. +func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{identityadoptiondecision.Label} + default: + return nil, &NotSingularError{identityadoptiondecision.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{identityadoptiondecision.Label} + default: + err = &NotSingularError{identityadoptiondecision.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of IdentityAdoptionDecisions. +func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]() + return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of IdentityAdoptionDecision IDs. +func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery { + if _q == nil { + return nil + } + return &IdentityAdoptionDecisionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]identityadoptiondecision.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...), + withPendingAuthSession: _q.withPendingAuthSession.Clone(), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSession = query + return _q +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// GroupBy(identityadoptiondecision.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &IdentityAdoptionDecisionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = identityadoptiondecision.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// Select(identityadoptiondecision.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q} + sbuild.label = identityadoptiondecision.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations. +func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !identityadoptiondecision.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) { + var ( + nodes = []*IdentityAdoptionDecision{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withPendingAuthSession != nil, + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*IdentityAdoptionDecision).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &IdentityAdoptionDecision{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withPendingAuthSession; query != nil { + if err := _q.loadPendingAuthSession(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil { + return nil, err + } + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + fk := nodes[i].PendingAuthSessionID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(pendingauthsession.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + if nodes[i].IdentityID == nil { + continue + } + fk := *nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for i := range fields { + if fields[i] != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withPendingAuthSession != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(identityadoptiondecision.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = identityadoptiondecision.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionGroupBy struct { + selector + build *IdentityAdoptionDecisionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionSelect struct { + *IdentityAdoptionDecisionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v) +} + +func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go new file mode 100644 index 00000000..0ca21d27 --- /dev/null +++ b/backend/ent/identityadoptiondecision_update.go @@ -0,0 +1,532 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionUpdate struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdentityAdoptionDecisionUpdate) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated IdentityAdoptionDecision entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdentityAdoptionDecisionUpdateOne) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for _, f := range fields { + if !identityadoptiondecision.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &IdentityAdoptionDecision{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 8d8320bb..157c5122 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -13,12 +13,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -228,6 +232,60 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + +// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error) @@ -309,6 +367,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier. +type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + +// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser. +type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error) @@ -390,6 +475,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q) +} + +// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser. +type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier. type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error) @@ -808,18 +920,26 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil case *ent.AnnouncementReadQuery: return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil + case *ent.AuthIdentityQuery: + return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil + case *ent.AuthIdentityChannelQuery: + return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil case *ent.ErrorPassthroughRuleQuery: return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil case *ent.IdempotencyRecordQuery: return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil + case *ent.IdentityAdoptionDecisionQuery: + return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil case *ent.PaymentAuditLogQuery: return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil case *ent.PaymentOrderQuery: return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil case *ent.PaymentProviderInstanceQuery: return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil + case *ent.PendingAuthSessionQuery: + return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil case *ent.PromoCodeQuery: return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil case *ent.PromoCodeUsageQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 68bdbf55..bf41e73b 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -338,6 +338,89 @@ var ( }, }, } + // AuthIdentitiesColumns holds the columns for the "auth_identities" table. + AuthIdentitiesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "user_id", Type: field.TypeInt64}, + } + // AuthIdentitiesTable holds the schema information for the "auth_identities" table. + AuthIdentitiesTable = &schema.Table{ + Name: "auth_identities", + Columns: AuthIdentitiesColumns, + PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identities_users_auth_identities", + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentity_provider_type_provider_key_provider_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]}, + }, + { + Name: "authidentity_user_id", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + }, + { + Name: "authidentity_user_id_provider_type", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]}, + }, + }, + } + // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table. + AuthIdentityChannelsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel", Type: field.TypeString, Size: 20}, + {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "identity_id", Type: field.TypeInt64}, + } + // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table. + AuthIdentityChannelsTable = &schema.Table{ + Name: "auth_identity_channels", + Columns: AuthIdentityChannelsColumns, + PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identity_channels_auth_identities_channels", + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]}, + }, + { + Name: "authidentitychannel_identity_id", + Unique: false, + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + }, + }, + } // ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table. ErrorPassthroughRulesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -485,6 +568,49 @@ var ( }, }, } + // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table. + IdentityAdoptionDecisionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "adopt_display_name", Type: field.TypeBool, Default: false}, + {Name: "adopt_avatar", Type: field.TypeBool, Default: false}, + {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "identity_id", Type: field.TypeInt64, Nullable: true}, + {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true}, + } + // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table. + IdentityAdoptionDecisionsTable = &schema.Table{ + Name: "identity_adoption_decisions", + Columns: IdentityAdoptionDecisionsColumns, + PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "identityadoptiondecision_pending_auth_session_id", + Unique: true, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + }, + { + Name: "identityadoptiondecision_identity_id", + Unique: false, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + }, + }, + } // PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table. PaymentAuditLogsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -638,6 +764,72 @@ var ( }, }, } + // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table. + PendingAuthSessionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "session_token", Type: field.TypeString, Size: 255}, + {Name: "intent", Type: field.TypeString, Size: 40}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "target_user_id", Type: field.TypeInt64, Nullable: true}, + } + // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table. + PendingAuthSessionsTable = &schema.Table{ + Name: "pending_auth_sessions", + Columns: PendingAuthSessionsColumns, + PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "pending_auth_sessions_users_pending_auth_sessions", + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + Indexes: []*schema.Index{ + { + Name: "pendingauthsession_session_token", + Unique: true, + Columns: []*schema.Column{PendingAuthSessionsColumns[3]}, + }, + { + Name: "pendingauthsession_target_user_id", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + }, + { + Name: "pendingauthsession_expires_at", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[19]}, + }, + { + Name: "pendingauthsession_provider_type_provider_key_provider_subject", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]}, + }, + { + Name: "pendingauthsession_completion_code_hash", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[14]}, + }, + }, + } // PromoCodesColumns holds the columns for the "promo_codes" table. PromoCodesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -1079,6 +1271,9 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"}, + {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true}, {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"}, {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, @@ -1318,12 +1513,16 @@ var ( AccountGroupsTable, AnnouncementsTable, AnnouncementReadsTable, + AuthIdentitiesTable, + AuthIdentityChannelsTable, ErrorPassthroughRulesTable, GroupsTable, IdempotencyRecordsTable, + IdentityAdoptionDecisionsTable, PaymentAuditLogsTable, PaymentOrdersTable, PaymentProviderInstancesTable, + PendingAuthSessionsTable, PromoCodesTable, PromoCodeUsagesTable, ProxiesTable, @@ -1365,6 +1564,14 @@ func init() { AnnouncementReadsTable.Annotation = &entsql.Annotation{ Table: "announcement_reads", } + AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable + AuthIdentitiesTable.Annotation = &entsql.Annotation{ + Table: "auth_identities", + } + AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + AuthIdentityChannelsTable.Annotation = &entsql.Annotation{ + Table: "auth_identity_channels", + } ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{ Table: "error_passthrough_rules", } @@ -1374,6 +1581,11 @@ func init() { IdempotencyRecordsTable.Annotation = &entsql.Annotation{ Table: "idempotency_records", } + IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable + IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{ + Table: "identity_adoption_decisions", + } PaymentAuditLogsTable.Annotation = &entsql.Annotation{ Table: "payment_audit_logs", } @@ -1384,6 +1596,10 @@ func init() { PaymentProviderInstancesTable.Annotation = &entsql.Annotation{ Table: "payment_provider_instances", } + PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable + PendingAuthSessionsTable.Annotation = &entsql.Annotation{ + Table: "pending_auth_sessions", + } PromoCodesTable.Annotation = &entsql.Annotation{ Table: "promo_codes", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 524ccb92..12905c9a 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -17,12 +17,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -51,32 +55,36 @@ const ( OpUpdateOne = ent.OpUpdateOne // Node types. - TypeAPIKey = "APIKey" - TypeAccount = "Account" - TypeAccountGroup = "AccountGroup" - TypeAnnouncement = "Announcement" - TypeAnnouncementRead = "AnnouncementRead" - TypeErrorPassthroughRule = "ErrorPassthroughRule" - TypeGroup = "Group" - TypeIdempotencyRecord = "IdempotencyRecord" - TypePaymentAuditLog = "PaymentAuditLog" - TypePaymentOrder = "PaymentOrder" - TypePaymentProviderInstance = "PaymentProviderInstance" - TypePromoCode = "PromoCode" - TypePromoCodeUsage = "PromoCodeUsage" - TypeProxy = "Proxy" - TypeRedeemCode = "RedeemCode" - TypeSecuritySecret = "SecuritySecret" - TypeSetting = "Setting" - TypeSubscriptionPlan = "SubscriptionPlan" - TypeTLSFingerprintProfile = "TLSFingerprintProfile" - TypeUsageCleanupTask = "UsageCleanupTask" - TypeUsageLog = "UsageLog" - TypeUser = "User" - TypeUserAllowedGroup = "UserAllowedGroup" - TypeUserAttributeDefinition = "UserAttributeDefinition" - TypeUserAttributeValue = "UserAttributeValue" - TypeUserSubscription = "UserSubscription" + TypeAPIKey = "APIKey" + TypeAccount = "Account" + TypeAccountGroup = "AccountGroup" + TypeAnnouncement = "Announcement" + TypeAnnouncementRead = "AnnouncementRead" + TypeAuthIdentity = "AuthIdentity" + TypeAuthIdentityChannel = "AuthIdentityChannel" + TypeErrorPassthroughRule = "ErrorPassthroughRule" + TypeGroup = "Group" + TypeIdempotencyRecord = "IdempotencyRecord" + TypeIdentityAdoptionDecision = "IdentityAdoptionDecision" + TypePaymentAuditLog = "PaymentAuditLog" + TypePaymentOrder = "PaymentOrder" + TypePaymentProviderInstance = "PaymentProviderInstance" + TypePendingAuthSession = "PendingAuthSession" + TypePromoCode = "PromoCode" + TypePromoCodeUsage = "PromoCodeUsage" + TypeProxy = "Proxy" + TypeRedeemCode = "RedeemCode" + TypeSecuritySecret = "SecuritySecret" + TypeSetting = "Setting" + TypeSubscriptionPlan = "SubscriptionPlan" + TypeTLSFingerprintProfile = "TLSFingerprintProfile" + TypeUsageCleanupTask = "UsageCleanupTask" + TypeUsageLog = "UsageLog" + TypeUser = "User" + TypeUserAllowedGroup = "UserAllowedGroup" + TypeUserAttributeDefinition = "UserAttributeDefinition" + TypeUserAttributeValue = "UserAttributeValue" + TypeUserSubscription = "UserSubscription" ) // APIKeyMutation represents an operation that mutates the APIKey nodes in the graph. @@ -6887,6 +6895,1845 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { return fmt.Errorf("unknown AnnouncementRead edge %s", name) } +// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph. +type AuthIdentityMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + provider_subject *string + verified_at *time.Time + issuer *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + user *int64 + cleareduser bool + channels map[int64]struct{} + removedchannels map[int64]struct{} + clearedchannels bool + adoption_decisions map[int64]struct{} + removedadoption_decisions map[int64]struct{} + clearedadoption_decisions bool + done bool + oldValue func(context.Context) (*AuthIdentity, error) + predicates []predicate.AuthIdentity +} + +var _ ent.Mutation = (*AuthIdentityMutation)(nil) + +// authidentityOption allows management of the mutation configuration using functional options. +type authidentityOption func(*AuthIdentityMutation) + +// newAuthIdentityMutation creates new mutation for the AuthIdentity entity. +func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation { + m := &AuthIdentityMutation{ + config: c, + op: op, + typ: TypeAuthIdentity, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withAuthIdentityID sets the ID field of the mutation. +func withAuthIdentityID(id int64) authidentityOption { + return func(m *AuthIdentityMutation) { + var ( + err error + once sync.Once + value *AuthIdentity + ) + m.oldValue = func(ctx context.Context) (*AuthIdentity, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().AuthIdentity.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withAuthIdentity sets the old AuthIdentity of the mutation. +func withAuthIdentity(node *AuthIdentity) authidentityOption { + return func(m *AuthIdentityMutation) { + m.oldValue = func(context.Context) (*AuthIdentity, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m AuthIdentityMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m AuthIdentityMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *AuthIdentityMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *AuthIdentityMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *AuthIdentityMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetUserID sets the "user_id" field. +func (m *AuthIdentityMutation) SetUserID(i int64) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *AuthIdentityMutation) UserID() (r int64, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// ResetUserID resets all changes to the "user_id" field. +func (m *AuthIdentityMutation) ResetUserID() { + m.user = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *AuthIdentityMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *AuthIdentityMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetProviderSubject sets the "provider_subject" field. +func (m *AuthIdentityMutation) SetProviderSubject(s string) { + m.provider_subject = &s +} + +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject + if v == nil { + return + } + return *v, true +} + +// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) + } + return oldValue.ProviderSubject, nil +} + +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *AuthIdentityMutation) ResetProviderSubject() { + m.provider_subject = nil +} + +// SetVerifiedAt sets the "verified_at" field. +func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) { + m.verified_at = &t +} + +// VerifiedAt returns the value of the "verified_at" field in the mutation. +func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) { + v := m.verified_at + if v == nil { + return + } + return *v, true +} + +// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldVerifiedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err) + } + return oldValue.VerifiedAt, nil +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (m *AuthIdentityMutation) ClearVerifiedAt() { + m.verified_at = nil + m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{} +} + +// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation. +func (m *AuthIdentityMutation) VerifiedAtCleared() bool { + _, ok := m.clearedFields[authidentity.FieldVerifiedAt] + return ok +} + +// ResetVerifiedAt resets all changes to the "verified_at" field. +func (m *AuthIdentityMutation) ResetVerifiedAt() { + m.verified_at = nil + delete(m.clearedFields, authidentity.FieldVerifiedAt) +} + +// SetIssuer sets the "issuer" field. +func (m *AuthIdentityMutation) SetIssuer(s string) { + m.issuer = &s +} + +// Issuer returns the value of the "issuer" field in the mutation. +func (m *AuthIdentityMutation) Issuer() (r string, exists bool) { + v := m.issuer + if v == nil { + return + } + return *v, true +} + +// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIssuer is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIssuer requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIssuer: %w", err) + } + return oldValue.Issuer, nil +} + +// ClearIssuer clears the value of the "issuer" field. +func (m *AuthIdentityMutation) ClearIssuer() { + m.issuer = nil + m.clearedFields[authidentity.FieldIssuer] = struct{}{} +} + +// IssuerCleared returns if the "issuer" field was cleared in this mutation. +func (m *AuthIdentityMutation) IssuerCleared() bool { + _, ok := m.clearedFields[authidentity.FieldIssuer] + return ok +} + +// ResetIssuer resets all changes to the "issuer" field. +func (m *AuthIdentityMutation) ResetIssuer() { + m.issuer = nil + delete(m.clearedFields, authidentity.FieldIssuer) +} + +// SetMetadata sets the "metadata" field. +func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value +} + +// Metadata returns the value of the "metadata" field in the mutation. +func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata + if v == nil { + return + } + return *v, true +} + +// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMetadata requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) + } + return oldValue.Metadata, nil +} + +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityMutation) ResetMetadata() { + m.metadata = nil +} + +// ClearUser clears the "user" edge to the User entity. +func (m *AuthIdentityMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[authidentity.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *AuthIdentityMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *AuthIdentityMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *AuthIdentityMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids. +func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) { + if m.channels == nil { + m.channels = make(map[int64]struct{}) + } + for i := range ids { + m.channels[ids[i]] = struct{}{} + } +} + +// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) ClearChannels() { + m.clearedchannels = true +} + +// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared. +func (m *AuthIdentityMutation) ChannelsCleared() bool { + return m.clearedchannels +} + +// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs. +func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) { + if m.removedchannels == nil { + m.removedchannels = make(map[int64]struct{}) + } + for i := range ids { + delete(m.channels, ids[i]) + m.removedchannels[ids[i]] = struct{}{} + } +} + +// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) { + for id := range m.removedchannels { + ids = append(ids, id) + } + return +} + +// ChannelsIDs returns the "channels" edge IDs in the mutation. +func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) { + for id := range m.channels { + ids = append(ids, id) + } + return +} + +// ResetChannels resets all changes to the "channels" edge. +func (m *AuthIdentityMutation) ResetChannels() { + m.channels = nil + m.clearedchannels = false + m.removedchannels = nil +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids. +func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) { + if m.adoption_decisions == nil { + m.adoption_decisions = make(map[int64]struct{}) + } + for i := range ids { + m.adoption_decisions[ids[i]] = struct{}{} + } +} + +// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) ClearAdoptionDecisions() { + m.clearedadoption_decisions = true +} + +// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared. +func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool { + return m.clearedadoption_decisions +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) { + if m.removedadoption_decisions == nil { + m.removedadoption_decisions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.adoption_decisions, ids[i]) + m.removedadoption_decisions[ids[i]] = struct{}{} + } +} + +// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) { + for id := range m.removedadoption_decisions { + ids = append(ids, id) + } + return +} + +// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation. +func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) { + for id := range m.adoption_decisions { + ids = append(ids, id) + } + return +} + +// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge. +func (m *AuthIdentityMutation) ResetAdoptionDecisions() { + m.adoption_decisions = nil + m.clearedadoption_decisions = false + m.removedadoption_decisions = nil +} + +// Where appends a list predicates to the AuthIdentityMutation builder. +func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentity, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *AuthIdentityMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AuthIdentityMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AuthIdentity). +func (m *AuthIdentityMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentity.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, authidentity.FieldUpdatedAt) + } + if m.user != nil { + fields = append(fields, authidentity.FieldUserID) + } + if m.provider_type != nil { + fields = append(fields, authidentity.FieldProviderType) + } + if m.provider_key != nil { + fields = append(fields, authidentity.FieldProviderKey) + } + if m.provider_subject != nil { + fields = append(fields, authidentity.FieldProviderSubject) + } + if m.verified_at != nil { + fields = append(fields, authidentity.FieldVerifiedAt) + } + if m.issuer != nil { + fields = append(fields, authidentity.FieldIssuer) + } + if m.metadata != nil { + fields = append(fields, authidentity.FieldMetadata) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) { + switch name { + case authidentity.FieldCreatedAt: + return m.CreatedAt() + case authidentity.FieldUpdatedAt: + return m.UpdatedAt() + case authidentity.FieldUserID: + return m.UserID() + case authidentity.FieldProviderType: + return m.ProviderType() + case authidentity.FieldProviderKey: + return m.ProviderKey() + case authidentity.FieldProviderSubject: + return m.ProviderSubject() + case authidentity.FieldVerifiedAt: + return m.VerifiedAt() + case authidentity.FieldIssuer: + return m.Issuer() + case authidentity.FieldMetadata: + return m.Metadata() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authidentity.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authidentity.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case authidentity.FieldUserID: + return m.OldUserID(ctx) + case authidentity.FieldProviderType: + return m.OldProviderType(ctx) + case authidentity.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentity.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case authidentity.FieldVerifiedAt: + return m.OldVerifiedAt(ctx) + case authidentity.FieldIssuer: + return m.OldIssuer(ctx) + case authidentity.FieldMetadata: + return m.OldMetadata(ctx) + } + return nil, fmt.Errorf("unknown AuthIdentity field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error { + switch name { + case authidentity.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case authidentity.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case authidentity.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case authidentity.FieldProviderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderType(v) + return nil + case authidentity.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case authidentity.FieldProviderSubject: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderSubject(v) + return nil + case authidentity.FieldVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetVerifiedAt(v) + return nil + case authidentity.FieldIssuer: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIssuer(v) + return nil + case authidentity.FieldMetadata: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMetadata(v) + return nil + } + return fmt.Errorf("unknown AuthIdentity field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AuthIdentityMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown AuthIdentity numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AuthIdentityMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(authidentity.FieldVerifiedAt) { + fields = append(fields, authidentity.FieldVerifiedAt) + } + if m.FieldCleared(authidentity.FieldIssuer) { + fields = append(fields, authidentity.FieldIssuer) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AuthIdentityMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AuthIdentityMutation) ClearField(name string) error { + switch name { + case authidentity.FieldVerifiedAt: + m.ClearVerifiedAt() + return nil + case authidentity.FieldIssuer: + m.ClearIssuer() + return nil + } + return fmt.Errorf("unknown AuthIdentity nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *AuthIdentityMutation) ResetField(name string) error { + switch name { + case authidentity.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case authidentity.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case authidentity.FieldUserID: + m.ResetUserID() + return nil + case authidentity.FieldProviderType: + m.ResetProviderType() + return nil + case authidentity.FieldProviderKey: + m.ResetProviderKey() + return nil + case authidentity.FieldProviderSubject: + m.ResetProviderSubject() + return nil + case authidentity.FieldVerifiedAt: + m.ResetVerifiedAt() + return nil + case authidentity.FieldIssuer: + m.ResetIssuer() + return nil + case authidentity.FieldMetadata: + m.ResetMetadata() + return nil + } + return fmt.Errorf("unknown AuthIdentity field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AuthIdentityMutation) AddedEdges() []string { + edges := make([]string, 0, 3) + if m.user != nil { + edges = append(edges, authidentity.EdgeUser) + } + if m.channels != nil { + edges = append(edges, authidentity.EdgeChannels) + } + if m.adoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.channels)) + for id := range m.channels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.adoption_decisions)) + for id := range m.adoption_decisions { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AuthIdentityMutation) RemovedEdges() []string { + edges := make([]string, 0, 3) + if m.removedchannels != nil { + edges = append(edges, authidentity.EdgeChannels) + } + if m.removedadoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.removedchannels)) + for id := range m.removedchannels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.removedadoption_decisions)) + for id := range m.removedadoption_decisions { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AuthIdentityMutation) ClearedEdges() []string { + edges := make([]string, 0, 3) + if m.cleareduser { + edges = append(edges, authidentity.EdgeUser) + } + if m.clearedchannels { + edges = append(edges, authidentity.EdgeChannels) + } + if m.clearedadoption_decisions { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *AuthIdentityMutation) EdgeCleared(name string) bool { + switch name { + case authidentity.EdgeUser: + return m.cleareduser + case authidentity.EdgeChannels: + return m.clearedchannels + case authidentity.EdgeAdoptionDecisions: + return m.clearedadoption_decisions + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AuthIdentityMutation) ClearEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown AuthIdentity unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AuthIdentityMutation) ResetEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ResetUser() + return nil + case authidentity.EdgeChannels: + m.ResetChannels() + return nil + case authidentity.EdgeAdoptionDecisions: + m.ResetAdoptionDecisions() + return nil + } + return fmt.Errorf("unknown AuthIdentity edge %s", name) +} + +// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph. +type AuthIdentityChannelMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + channel *string + channel_app_id *string + channel_subject *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*AuthIdentityChannel, error) + predicates []predicate.AuthIdentityChannel +} + +var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil) + +// authidentitychannelOption allows management of the mutation configuration using functional options. +type authidentitychannelOption func(*AuthIdentityChannelMutation) + +// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity. +func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation { + m := &AuthIdentityChannelMutation{ + config: c, + op: op, + typ: TypeAuthIdentityChannel, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withAuthIdentityChannelID sets the ID field of the mutation. +func withAuthIdentityChannelID(id int64) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { + var ( + err error + once sync.Once + value *AuthIdentityChannel + ) + m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().AuthIdentityChannel.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation. +func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { + m.oldValue = func(context.Context) (*AuthIdentityChannel, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m AuthIdentityChannelMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m AuthIdentityChannelMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *AuthIdentityChannelMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *AuthIdentityChannelMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetIdentityID sets the "identity_id" field. +func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) { + m.identity = &i +} + +// IdentityID returns the value of the "identity_id" field in the mutation. +func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) { + v := m.identity + if v == nil { + return + } + return *v, true +} + +// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdentityID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) + } + return oldValue.IdentityID, nil +} + +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *AuthIdentityChannelMutation) ResetIdentityID() { + m.identity = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *AuthIdentityChannelMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *AuthIdentityChannelMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityChannelMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityChannelMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetChannel sets the "channel" field. +func (m *AuthIdentityChannelMutation) SetChannel(s string) { + m.channel = &s +} + +// Channel returns the value of the "channel" field in the mutation. +func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) { + v := m.channel + if v == nil { + return + } + return *v, true +} + +// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannel: %w", err) + } + return oldValue.Channel, nil +} + +// ResetChannel resets all changes to the "channel" field. +func (m *AuthIdentityChannelMutation) ResetChannel() { + m.channel = nil +} + +// SetChannelAppID sets the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) { + m.channel_app_id = &s +} + +// ChannelAppID returns the value of the "channel_app_id" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) { + v := m.channel_app_id + if v == nil { + return + } + return *v, true +} + +// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelAppID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err) + } + return oldValue.ChannelAppID, nil +} + +// ResetChannelAppID resets all changes to the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) ResetChannelAppID() { + m.channel_app_id = nil +} + +// SetChannelSubject sets the "channel_subject" field. +func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) { + m.channel_subject = &s +} + +// ChannelSubject returns the value of the "channel_subject" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) { + v := m.channel_subject + if v == nil { + return + } + return *v, true +} + +// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err) + } + return oldValue.ChannelSubject, nil +} + +// ResetChannelSubject resets all changes to the "channel_subject" field. +func (m *AuthIdentityChannelMutation) ResetChannelSubject() { + m.channel_subject = nil +} + +// SetMetadata sets the "metadata" field. +func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value +} + +// Metadata returns the value of the "metadata" field in the mutation. +func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata + if v == nil { + return + } + return *v, true +} + +// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMetadata requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) + } + return oldValue.Metadata, nil +} + +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityChannelMutation) ResetMetadata() { + m.metadata = nil +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (m *AuthIdentityChannelMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{} +} + +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *AuthIdentityChannelMutation) IdentityCleared() bool { + return m.clearedidentity +} + +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. +func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetIdentity resets all changes to the "identity" edge. +func (m *AuthIdentityChannelMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false +} + +// Where appends a list predicates to the AuthIdentityChannelMutation builder. +func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentityChannel, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *AuthIdentityChannelMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AuthIdentityChannelMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AuthIdentityChannel). +func (m *AuthIdentityChannelMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityChannelMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentitychannel.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, authidentitychannel.FieldUpdatedAt) + } + if m.identity != nil { + fields = append(fields, authidentitychannel.FieldIdentityID) + } + if m.provider_type != nil { + fields = append(fields, authidentitychannel.FieldProviderType) + } + if m.provider_key != nil { + fields = append(fields, authidentitychannel.FieldProviderKey) + } + if m.channel != nil { + fields = append(fields, authidentitychannel.FieldChannel) + } + if m.channel_app_id != nil { + fields = append(fields, authidentitychannel.FieldChannelAppID) + } + if m.channel_subject != nil { + fields = append(fields, authidentitychannel.FieldChannelSubject) + } + if m.metadata != nil { + fields = append(fields, authidentitychannel.FieldMetadata) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.CreatedAt() + case authidentitychannel.FieldUpdatedAt: + return m.UpdatedAt() + case authidentitychannel.FieldIdentityID: + return m.IdentityID() + case authidentitychannel.FieldProviderType: + return m.ProviderType() + case authidentitychannel.FieldProviderKey: + return m.ProviderKey() + case authidentitychannel.FieldChannel: + return m.Channel() + case authidentitychannel.FieldChannelAppID: + return m.ChannelAppID() + case authidentitychannel.FieldChannelSubject: + return m.ChannelSubject() + case authidentitychannel.FieldMetadata: + return m.Metadata() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authidentitychannel.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case authidentitychannel.FieldIdentityID: + return m.OldIdentityID(ctx) + case authidentitychannel.FieldProviderType: + return m.OldProviderType(ctx) + case authidentitychannel.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentitychannel.FieldChannel: + return m.OldChannel(ctx) + case authidentitychannel.FieldChannelAppID: + return m.OldChannelAppID(ctx) + case authidentitychannel.FieldChannelSubject: + return m.OldChannelSubject(ctx) + case authidentitychannel.FieldMetadata: + return m.OldMetadata(ctx) + } + return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error { + switch name { + case authidentitychannel.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case authidentitychannel.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case authidentitychannel.FieldIdentityID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdentityID(v) + return nil + case authidentitychannel.FieldProviderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderType(v) + return nil + case authidentitychannel.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case authidentitychannel.FieldChannel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannel(v) + return nil + case authidentitychannel.FieldChannelAppID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannelAppID(v) + return nil + case authidentitychannel.FieldChannelSubject: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannelSubject(v) + return nil + case authidentitychannel.FieldMetadata: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMetadata(v) + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AuthIdentityChannelMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AuthIdentityChannelMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AuthIdentityChannelMutation) ClearField(name string) error { + return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *AuthIdentityChannelMutation) ResetField(name string) error { + switch name { + case authidentitychannel.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case authidentitychannel.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case authidentitychannel.FieldIdentityID: + m.ResetIdentityID() + return nil + case authidentitychannel.FieldProviderType: + m.ResetProviderType() + return nil + case authidentitychannel.FieldProviderKey: + m.ResetProviderKey() + return nil + case authidentitychannel.FieldChannel: + m.ResetChannel() + return nil + case authidentitychannel.FieldChannelAppID: + m.ResetChannelAppID() + return nil + case authidentitychannel.FieldChannelSubject: + m.ResetChannelSubject() + return nil + case authidentitychannel.FieldMetadata: + m.ResetMetadata() + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AuthIdentityChannelMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.identity != nil { + edges = append(edges, authidentitychannel.EdgeIdentity) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentitychannel.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AuthIdentityChannelMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AuthIdentityChannelMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedidentity { + edges = append(edges, authidentitychannel.EdgeIdentity) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool { + switch name { + case authidentitychannel.EdgeIdentity: + return m.clearedidentity + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AuthIdentityChannelMutation) ClearEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ClearIdentity() + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AuthIdentityChannelMutation) ResetEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ResetIdentity() + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel edge %s", name) +} + // ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. type ErrorPassthroughRuleMutation struct { config @@ -12191,6 +14038,781 @@ func (m *IdempotencyRecordMutation) ResetEdge(name string) error { return fmt.Errorf("unknown IdempotencyRecord edge %s", name) } +// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph. +type IdentityAdoptionDecisionMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + adopt_display_name *bool + adopt_avatar *bool + decided_at *time.Time + clearedFields map[string]struct{} + pending_auth_session *int64 + clearedpending_auth_session bool + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*IdentityAdoptionDecision, error) + predicates []predicate.IdentityAdoptionDecision +} + +var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil) + +// identityadoptiondecisionOption allows management of the mutation configuration using functional options. +type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation) + +// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity. +func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation { + m := &IdentityAdoptionDecisionMutation{ + config: c, + op: op, + typ: TypeIdentityAdoptionDecision, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdentityAdoptionDecisionID sets the ID field of the mutation. +func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + var ( + err error + once sync.Once + value *IdentityAdoptionDecision + ) + m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation. +func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdentityAdoptionDecisionMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) { + m.pending_auth_session = &i +} + +// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) { + v := m.pending_auth_session + if v == nil { + return + } + return *v, true +} + +// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err) + } + return oldValue.PendingAuthSessionID, nil +} + +// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() { + m.pending_auth_session = nil +} + +// SetIdentityID sets the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) { + m.identity = &i +} + +// IdentityID returns the value of the "identity_id" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) { + v := m.identity + if v == nil { + return + } + return *v, true +} + +// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdentityID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) + } + return oldValue.IdentityID, nil +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() { + m.identity = nil + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} +} + +// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool { + _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID] + return ok +} + +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() { + m.identity = nil + delete(m.clearedFields, identityadoptiondecision.FieldIdentityID) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) { + m.adopt_display_name = &b +} + +// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) { + v := m.adopt_display_name + if v == nil { + return + } + return *v, true +} + +// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err) + } + return oldValue.AdoptDisplayName, nil +} + +// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() { + m.adopt_display_name = nil +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) { + m.adopt_avatar = &b +} + +// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) { + v := m.adopt_avatar + if v == nil { + return + } + return *v, true +} + +// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAdoptAvatar requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err) + } + return oldValue.AdoptAvatar, nil +} + +// ResetAdoptAvatar resets all changes to the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() { + m.adopt_avatar = nil +} + +// SetDecidedAt sets the "decided_at" field. +func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) { + m.decided_at = &t +} + +// DecidedAt returns the value of the "decided_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) { + v := m.decided_at + if v == nil { + return + } + return *v, true +} + +// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDecidedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err) + } + return oldValue.DecidedAt, nil +} + +// ResetDecidedAt resets all changes to the "decided_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() { + m.decided_at = nil +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() { + m.clearedpending_auth_session = true + m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{} +} + +// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool { + return m.clearedpending_auth_session +} + +// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// PendingAuthSessionID instead. It exists only for internal usage by the builders. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) { + if id := m.pending_auth_session; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() { + m.pending_auth_session = nil + m.clearedpending_auth_session = false +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (m *IdentityAdoptionDecisionMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} +} + +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool { + return m.IdentityIDCleared() || m.clearedidentity +} + +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. +func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetIdentity resets all changes to the "identity" edge. +func (m *IdentityAdoptionDecisionMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false +} + +// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder. +func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdentityAdoptionDecision, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *IdentityAdoptionDecisionMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (IdentityAdoptionDecision). +func (m *IdentityAdoptionDecisionMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *IdentityAdoptionDecisionMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.created_at != nil { + fields = append(fields, identityadoptiondecision.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, identityadoptiondecision.FieldUpdatedAt) + } + if m.pending_auth_session != nil { + fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID) + } + if m.identity != nil { + fields = append(fields, identityadoptiondecision.FieldIdentityID) + } + if m.adopt_display_name != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName) + } + if m.adopt_avatar != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptAvatar) + } + if m.decided_at != nil { + fields = append(fields, identityadoptiondecision.FieldDecidedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.CreatedAt() + case identityadoptiondecision.FieldUpdatedAt: + return m.UpdatedAt() + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.PendingAuthSessionID() + case identityadoptiondecision.FieldIdentityID: + return m.IdentityID() + case identityadoptiondecision.FieldAdoptDisplayName: + return m.AdoptDisplayName() + case identityadoptiondecision.FieldAdoptAvatar: + return m.AdoptAvatar() + case identityadoptiondecision.FieldDecidedAt: + return m.DecidedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case identityadoptiondecision.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.OldPendingAuthSessionID(ctx) + case identityadoptiondecision.FieldIdentityID: + return m.OldIdentityID(ctx) + case identityadoptiondecision.FieldAdoptDisplayName: + return m.OldAdoptDisplayName(ctx) + case identityadoptiondecision.FieldAdoptAvatar: + return m.OldAdoptAvatar(ctx) + case identityadoptiondecision.FieldDecidedAt: + return m.OldDecidedAt(ctx) + } + return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case identityadoptiondecision.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPendingAuthSessionID(v) + return nil + case identityadoptiondecision.FieldIdentityID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdentityID(v) + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdoptDisplayName(v) + return nil + case identityadoptiondecision.FieldAdoptAvatar: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdoptAvatar(v) + return nil + case identityadoptiondecision.FieldDecidedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDecidedAt(v) + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(identityadoptiondecision.FieldIdentityID) { + fields = append(fields, identityadoptiondecision.FieldIdentityID) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error { + switch name { + case identityadoptiondecision.FieldIdentityID: + m.ClearIdentityID() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case identityadoptiondecision.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + m.ResetPendingAuthSessionID() + return nil + case identityadoptiondecision.FieldIdentityID: + m.ResetIdentityID() + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + m.ResetAdoptDisplayName() + return nil + case identityadoptiondecision.FieldAdoptAvatar: + m.ResetAdoptAvatar() + return nil + case identityadoptiondecision.FieldDecidedAt: + m.ResetDecidedAt() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.pending_auth_session != nil { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) + } + if m.identity != nil { + edges = append(edges, identityadoptiondecision.EdgeIdentity) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + if id := m.pending_auth_session; id != nil { + return []ent.Value{*id} + } + case identityadoptiondecision.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedpending_auth_session { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) + } + if m.clearedidentity { + edges = append(edges, identityadoptiondecision.EdgeIdentity) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + return m.clearedpending_auth_session + case identityadoptiondecision.EdgeIdentity: + return m.clearedidentity + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ClearPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ClearIdentity() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ResetPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ResetIdentity() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name) +} + // PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph. type PaymentAuditLogMutation struct { config @@ -16595,6 +19217,1645 @@ func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error { return fmt.Errorf("unknown PaymentProviderInstance edge %s", name) } +// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph. +type PendingAuthSessionMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + session_token *string + intent *string + provider_type *string + provider_key *string + provider_subject *string + redirect_to *string + resolved_email *string + registration_password_hash *string + upstream_identity_claims *map[string]interface{} + local_flow_state *map[string]interface{} + browser_session_key *string + completion_code_hash *string + completion_code_expires_at *time.Time + email_verified_at *time.Time + password_verified_at *time.Time + totp_verified_at *time.Time + expires_at *time.Time + consumed_at *time.Time + clearedFields map[string]struct{} + target_user *int64 + clearedtarget_user bool + adoption_decision *int64 + clearedadoption_decision bool + done bool + oldValue func(context.Context) (*PendingAuthSession, error) + predicates []predicate.PendingAuthSession +} + +var _ ent.Mutation = (*PendingAuthSessionMutation)(nil) + +// pendingauthsessionOption allows management of the mutation configuration using functional options. +type pendingauthsessionOption func(*PendingAuthSessionMutation) + +// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity. +func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation { + m := &PendingAuthSessionMutation{ + config: c, + op: op, + typ: TypePendingAuthSession, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withPendingAuthSessionID sets the ID field of the mutation. +func withPendingAuthSessionID(id int64) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { + var ( + err error + once sync.Once + value *PendingAuthSession + ) + m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PendingAuthSession.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withPendingAuthSession sets the old PendingAuthSession of the mutation. +func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { + m.oldValue = func(context.Context) (*PendingAuthSession, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PendingAuthSessionMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PendingAuthSessionMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PendingAuthSessionMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PendingAuthSessionMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetSessionToken sets the "session_token" field. +func (m *PendingAuthSessionMutation) SetSessionToken(s string) { + m.session_token = &s +} + +// SessionToken returns the value of the "session_token" field in the mutation. +func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) { + v := m.session_token + if v == nil { + return + } + return *v, true +} + +// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSessionToken is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSessionToken requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSessionToken: %w", err) + } + return oldValue.SessionToken, nil +} + +// ResetSessionToken resets all changes to the "session_token" field. +func (m *PendingAuthSessionMutation) ResetSessionToken() { + m.session_token = nil +} + +// SetIntent sets the "intent" field. +func (m *PendingAuthSessionMutation) SetIntent(s string) { + m.intent = &s +} + +// Intent returns the value of the "intent" field in the mutation. +func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) { + v := m.intent + if v == nil { + return + } + return *v, true +} + +// OldIntent returns the old "intent" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIntent is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIntent requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIntent: %w", err) + } + return oldValue.Intent, nil +} + +// ResetIntent resets all changes to the "intent" field. +func (m *PendingAuthSessionMutation) ResetIntent() { + m.intent = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *PendingAuthSessionMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *PendingAuthSessionMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *PendingAuthSessionMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PendingAuthSessionMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetProviderSubject sets the "provider_subject" field. +func (m *PendingAuthSessionMutation) SetProviderSubject(s string) { + m.provider_subject = &s +} + +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject + if v == nil { + return + } + return *v, true +} + +// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) + } + return oldValue.ProviderSubject, nil +} + +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *PendingAuthSessionMutation) ResetProviderSubject() { + m.provider_subject = nil +} + +// SetTargetUserID sets the "target_user_id" field. +func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) { + m.target_user = &i +} + +// TargetUserID returns the value of the "target_user_id" field in the mutation. +func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) { + v := m.target_user + if v == nil { + return + } + return *v, true +} + +// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTargetUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err) + } + return oldValue.TargetUserID, nil +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (m *PendingAuthSessionMutation) ClearTargetUserID() { + m.target_user = nil + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} +} + +// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID] + return ok +} + +// ResetTargetUserID resets all changes to the "target_user_id" field. +func (m *PendingAuthSessionMutation) ResetTargetUserID() { + m.target_user = nil + delete(m.clearedFields, pendingauthsession.FieldTargetUserID) +} + +// SetRedirectTo sets the "redirect_to" field. +func (m *PendingAuthSessionMutation) SetRedirectTo(s string) { + m.redirect_to = &s +} + +// RedirectTo returns the value of the "redirect_to" field in the mutation. +func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) { + v := m.redirect_to + if v == nil { + return + } + return *v, true +} + +// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRedirectTo requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err) + } + return oldValue.RedirectTo, nil +} + +// ResetRedirectTo resets all changes to the "redirect_to" field. +func (m *PendingAuthSessionMutation) ResetRedirectTo() { + m.redirect_to = nil +} + +// SetResolvedEmail sets the "resolved_email" field. +func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) { + m.resolved_email = &s +} + +// ResolvedEmail returns the value of the "resolved_email" field in the mutation. +func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) { + v := m.resolved_email + if v == nil { + return + } + return *v, true +} + +// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResolvedEmail requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err) + } + return oldValue.ResolvedEmail, nil +} + +// ResetResolvedEmail resets all changes to the "resolved_email" field. +func (m *PendingAuthSessionMutation) ResetResolvedEmail() { + m.resolved_email = nil +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) { + m.registration_password_hash = &s +} + +// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation. +func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) { + v := m.registration_password_hash + if v == nil { + return + } + return *v, true +} + +// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err) + } + return oldValue.RegistrationPasswordHash, nil +} + +// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() { + m.registration_password_hash = nil +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) { + m.upstream_identity_claims = &value +} + +// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation. +func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) { + v := m.upstream_identity_claims + if v == nil { + return + } + return *v, true +} + +// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err) + } + return oldValue.UpstreamIdentityClaims, nil +} + +// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() { + m.upstream_identity_claims = nil +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) { + m.local_flow_state = &value +} + +// LocalFlowState returns the value of the "local_flow_state" field in the mutation. +func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) { + v := m.local_flow_state + if v == nil { + return + } + return *v, true +} + +// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLocalFlowState requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err) + } + return oldValue.LocalFlowState, nil +} + +// ResetLocalFlowState resets all changes to the "local_flow_state" field. +func (m *PendingAuthSessionMutation) ResetLocalFlowState() { + m.local_flow_state = nil +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) { + m.browser_session_key = &s +} + +// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation. +func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) { + v := m.browser_session_key + if v == nil { + return + } + return *v, true +} + +// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err) + } + return oldValue.BrowserSessionKey, nil +} + +// ResetBrowserSessionKey resets all changes to the "browser_session_key" field. +func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() { + m.browser_session_key = nil +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) { + m.completion_code_hash = &s +} + +// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation. +func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) { + v := m.completion_code_hash + if v == nil { + return + } + return *v, true +} + +// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err) + } + return oldValue.CompletionCodeHash, nil +} + +// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() { + m.completion_code_hash = nil +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) { + m.completion_code_expires_at = &t +} + +// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation. +func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) { + v := m.completion_code_expires_at + if v == nil { + return + } + return *v, true +} + +// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err) + } + return oldValue.CompletionCodeExpiresAt, nil +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{} +} + +// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] + return ok +} + +// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) { + m.email_verified_at = &t +} + +// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) { + v := m.email_verified_at + if v == nil { + return + } + return *v, true +} + +// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err) + } + return oldValue.EmailVerifiedAt, nil +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() { + m.email_verified_at = nil + m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{} +} + +// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] + return ok +} + +// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() { + m.email_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) { + m.password_verified_at = &t +} + +// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) { + v := m.password_verified_at + if v == nil { + return + } + return *v, true +} + +// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err) + } + return oldValue.PasswordVerifiedAt, nil +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() { + m.password_verified_at = nil + m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{} +} + +// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] + return ok +} + +// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field. +func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() { + m.password_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) { + m.totp_verified_at = &t +} + +// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) { + v := m.totp_verified_at + if v == nil { + return + } + return *v, true +} + +// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err) + } + return oldValue.TotpVerifiedAt, nil +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() { + m.totp_verified_at = nil + m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{} +} + +// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] + return ok +} + +// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() { + m.totp_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *PendingAuthSessionMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// SetConsumedAt sets the "consumed_at" field. +func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) { + m.consumed_at = &t +} + +// ConsumedAt returns the value of the "consumed_at" field in the mutation. +func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) { + v := m.consumed_at + if v == nil { + return + } + return *v, true +} + +// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldConsumedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err) + } + return oldValue.ConsumedAt, nil +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (m *PendingAuthSessionMutation) ClearConsumedAt() { + m.consumed_at = nil + m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{} +} + +// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt] + return ok +} + +// ResetConsumedAt resets all changes to the "consumed_at" field. +func (m *PendingAuthSessionMutation) ResetConsumedAt() { + m.consumed_at = nil + delete(m.clearedFields, pendingauthsession.FieldConsumedAt) +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (m *PendingAuthSessionMutation) ClearTargetUser() { + m.clearedtarget_user = true + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} +} + +// TargetUserCleared reports if the "target_user" edge to the User entity was cleared. +func (m *PendingAuthSessionMutation) TargetUserCleared() bool { + return m.TargetUserIDCleared() || m.clearedtarget_user +} + +// TargetUserIDs returns the "target_user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// TargetUserID instead. It exists only for internal usage by the builders. +func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) { + if id := m.target_user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetTargetUser resets all changes to the "target_user" edge. +func (m *PendingAuthSessionMutation) ResetTargetUser() { + m.target_user = nil + m.clearedtarget_user = false +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id. +func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) { + m.adoption_decision = &id +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (m *PendingAuthSessionMutation) ClearAdoptionDecision() { + m.clearedadoption_decision = true +} + +// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared. +func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool { + return m.clearedadoption_decision +} + +// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation. +func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) { + if m.adoption_decision != nil { + return *m.adoption_decision, true + } + return +} + +// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// AdoptionDecisionID instead. It exists only for internal usage by the builders. +func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) { + if id := m.adoption_decision; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetAdoptionDecision resets all changes to the "adoption_decision" edge. +func (m *PendingAuthSessionMutation) ResetAdoptionDecision() { + m.adoption_decision = nil + m.clearedadoption_decision = false +} + +// Where appends a list predicates to the PendingAuthSessionMutation builder. +func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PendingAuthSession, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *PendingAuthSessionMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *PendingAuthSessionMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (PendingAuthSession). +func (m *PendingAuthSessionMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PendingAuthSessionMutation) Fields() []string { + fields := make([]string, 0, 21) + if m.created_at != nil { + fields = append(fields, pendingauthsession.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, pendingauthsession.FieldUpdatedAt) + } + if m.session_token != nil { + fields = append(fields, pendingauthsession.FieldSessionToken) + } + if m.intent != nil { + fields = append(fields, pendingauthsession.FieldIntent) + } + if m.provider_type != nil { + fields = append(fields, pendingauthsession.FieldProviderType) + } + if m.provider_key != nil { + fields = append(fields, pendingauthsession.FieldProviderKey) + } + if m.provider_subject != nil { + fields = append(fields, pendingauthsession.FieldProviderSubject) + } + if m.target_user != nil { + fields = append(fields, pendingauthsession.FieldTargetUserID) + } + if m.redirect_to != nil { + fields = append(fields, pendingauthsession.FieldRedirectTo) + } + if m.resolved_email != nil { + fields = append(fields, pendingauthsession.FieldResolvedEmail) + } + if m.registration_password_hash != nil { + fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash) + } + if m.upstream_identity_claims != nil { + fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims) + } + if m.local_flow_state != nil { + fields = append(fields, pendingauthsession.FieldLocalFlowState) + } + if m.browser_session_key != nil { + fields = append(fields, pendingauthsession.FieldBrowserSessionKey) + } + if m.completion_code_hash != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeHash) + } + if m.completion_code_expires_at != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) + } + if m.email_verified_at != nil { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.password_verified_at != nil { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.totp_verified_at != nil { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.expires_at != nil { + fields = append(fields, pendingauthsession.FieldExpiresAt) + } + if m.consumed_at != nil { + fields = append(fields, pendingauthsession.FieldConsumedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) { + switch name { + case pendingauthsession.FieldCreatedAt: + return m.CreatedAt() + case pendingauthsession.FieldUpdatedAt: + return m.UpdatedAt() + case pendingauthsession.FieldSessionToken: + return m.SessionToken() + case pendingauthsession.FieldIntent: + return m.Intent() + case pendingauthsession.FieldProviderType: + return m.ProviderType() + case pendingauthsession.FieldProviderKey: + return m.ProviderKey() + case pendingauthsession.FieldProviderSubject: + return m.ProviderSubject() + case pendingauthsession.FieldTargetUserID: + return m.TargetUserID() + case pendingauthsession.FieldRedirectTo: + return m.RedirectTo() + case pendingauthsession.FieldResolvedEmail: + return m.ResolvedEmail() + case pendingauthsession.FieldRegistrationPasswordHash: + return m.RegistrationPasswordHash() + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.UpstreamIdentityClaims() + case pendingauthsession.FieldLocalFlowState: + return m.LocalFlowState() + case pendingauthsession.FieldBrowserSessionKey: + return m.BrowserSessionKey() + case pendingauthsession.FieldCompletionCodeHash: + return m.CompletionCodeHash() + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.CompletionCodeExpiresAt() + case pendingauthsession.FieldEmailVerifiedAt: + return m.EmailVerifiedAt() + case pendingauthsession.FieldPasswordVerifiedAt: + return m.PasswordVerifiedAt() + case pendingauthsession.FieldTotpVerifiedAt: + return m.TotpVerifiedAt() + case pendingauthsession.FieldExpiresAt: + return m.ExpiresAt() + case pendingauthsession.FieldConsumedAt: + return m.ConsumedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case pendingauthsession.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case pendingauthsession.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case pendingauthsession.FieldSessionToken: + return m.OldSessionToken(ctx) + case pendingauthsession.FieldIntent: + return m.OldIntent(ctx) + case pendingauthsession.FieldProviderType: + return m.OldProviderType(ctx) + case pendingauthsession.FieldProviderKey: + return m.OldProviderKey(ctx) + case pendingauthsession.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case pendingauthsession.FieldTargetUserID: + return m.OldTargetUserID(ctx) + case pendingauthsession.FieldRedirectTo: + return m.OldRedirectTo(ctx) + case pendingauthsession.FieldResolvedEmail: + return m.OldResolvedEmail(ctx) + case pendingauthsession.FieldRegistrationPasswordHash: + return m.OldRegistrationPasswordHash(ctx) + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.OldUpstreamIdentityClaims(ctx) + case pendingauthsession.FieldLocalFlowState: + return m.OldLocalFlowState(ctx) + case pendingauthsession.FieldBrowserSessionKey: + return m.OldBrowserSessionKey(ctx) + case pendingauthsession.FieldCompletionCodeHash: + return m.OldCompletionCodeHash(ctx) + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.OldCompletionCodeExpiresAt(ctx) + case pendingauthsession.FieldEmailVerifiedAt: + return m.OldEmailVerifiedAt(ctx) + case pendingauthsession.FieldPasswordVerifiedAt: + return m.OldPasswordVerifiedAt(ctx) + case pendingauthsession.FieldTotpVerifiedAt: + return m.OldTotpVerifiedAt(ctx) + case pendingauthsession.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case pendingauthsession.FieldConsumedAt: + return m.OldConsumedAt(ctx) + } + return nil, fmt.Errorf("unknown PendingAuthSession field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error { + switch name { + case pendingauthsession.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case pendingauthsession.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case pendingauthsession.FieldSessionToken: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSessionToken(v) + return nil + case pendingauthsession.FieldIntent: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIntent(v) + return nil + case pendingauthsession.FieldProviderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderType(v) + return nil + case pendingauthsession.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case pendingauthsession.FieldProviderSubject: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderSubject(v) + return nil + case pendingauthsession.FieldTargetUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTargetUserID(v) + return nil + case pendingauthsession.FieldRedirectTo: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRedirectTo(v) + return nil + case pendingauthsession.FieldResolvedEmail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResolvedEmail(v) + return nil + case pendingauthsession.FieldRegistrationPasswordHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRegistrationPasswordHash(v) + return nil + case pendingauthsession.FieldUpstreamIdentityClaims: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpstreamIdentityClaims(v) + return nil + case pendingauthsession.FieldLocalFlowState: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLocalFlowState(v) + return nil + case pendingauthsession.FieldBrowserSessionKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBrowserSessionKey(v) + return nil + case pendingauthsession.FieldCompletionCodeHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletionCodeHash(v) + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletionCodeExpiresAt(v) + return nil + case pendingauthsession.FieldEmailVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEmailVerifiedAt(v) + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPasswordVerifiedAt(v) + return nil + case pendingauthsession.FieldTotpVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpVerifiedAt(v) + return nil + case pendingauthsession.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case pendingauthsession.FieldConsumedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConsumedAt(v) + return nil + } + return fmt.Errorf("unknown PendingAuthSession field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *PendingAuthSessionMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown PendingAuthSession numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PendingAuthSessionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(pendingauthsession.FieldTargetUserID) { + fields = append(fields, pendingauthsession.FieldTargetUserID) + } + if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) + } + if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldConsumedAt) { + fields = append(fields, pendingauthsession.FieldConsumedAt) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PendingAuthSessionMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PendingAuthSessionMutation) ClearField(name string) error { + switch name { + case pendingauthsession.FieldTargetUserID: + m.ClearTargetUserID() + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ClearCompletionCodeExpiresAt() + return nil + case pendingauthsession.FieldEmailVerifiedAt: + m.ClearEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ClearPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ClearTotpVerifiedAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ClearConsumedAt() + return nil + } + return fmt.Errorf("unknown PendingAuthSession nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *PendingAuthSessionMutation) ResetField(name string) error { + switch name { + case pendingauthsession.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case pendingauthsession.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case pendingauthsession.FieldSessionToken: + m.ResetSessionToken() + return nil + case pendingauthsession.FieldIntent: + m.ResetIntent() + return nil + case pendingauthsession.FieldProviderType: + m.ResetProviderType() + return nil + case pendingauthsession.FieldProviderKey: + m.ResetProviderKey() + return nil + case pendingauthsession.FieldProviderSubject: + m.ResetProviderSubject() + return nil + case pendingauthsession.FieldTargetUserID: + m.ResetTargetUserID() + return nil + case pendingauthsession.FieldRedirectTo: + m.ResetRedirectTo() + return nil + case pendingauthsession.FieldResolvedEmail: + m.ResetResolvedEmail() + return nil + case pendingauthsession.FieldRegistrationPasswordHash: + m.ResetRegistrationPasswordHash() + return nil + case pendingauthsession.FieldUpstreamIdentityClaims: + m.ResetUpstreamIdentityClaims() + return nil + case pendingauthsession.FieldLocalFlowState: + m.ResetLocalFlowState() + return nil + case pendingauthsession.FieldBrowserSessionKey: + m.ResetBrowserSessionKey() + return nil + case pendingauthsession.FieldCompletionCodeHash: + m.ResetCompletionCodeHash() + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ResetCompletionCodeExpiresAt() + return nil + case pendingauthsession.FieldEmailVerifiedAt: + m.ResetEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ResetPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ResetTotpVerifiedAt() + return nil + case pendingauthsession.FieldExpiresAt: + m.ResetExpiresAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ResetConsumedAt() + return nil + } + return fmt.Errorf("unknown PendingAuthSession field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *PendingAuthSessionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.target_user != nil { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.adoption_decision != nil { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value { + switch name { + case pendingauthsession.EdgeTargetUser: + if id := m.target_user; id != nil { + return []ent.Value{*id} + } + case pendingauthsession.EdgeAdoptionDecision: + if id := m.adoption_decision; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *PendingAuthSessionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *PendingAuthSessionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedtarget_user { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.clearedadoption_decision { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool { + switch name { + case pendingauthsession.EdgeTargetUser: + return m.clearedtarget_user + case pendingauthsession.EdgeAdoptionDecision: + return m.clearedadoption_decision + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PendingAuthSessionMutation) ClearEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ClearTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ClearAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PendingAuthSessionMutation) ResetEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ResetTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ResetAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession edge %s", name) +} + // PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph. type PromoCodeMutation struct { config @@ -28264,6 +32525,9 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + signup_source *string + last_login_at *time.Time + last_active_at *time.Time balance_notify_enabled *bool balance_notify_threshold_type *string balance_notify_threshold *float64 @@ -28302,6 +32566,12 @@ type UserMutation struct { payment_orders map[int64]struct{} removedpayment_orders map[int64]struct{} clearedpayment_orders bool + auth_identities map[int64]struct{} + removedauth_identities map[int64]struct{} + clearedauth_identities bool + pending_auth_sessions map[int64]struct{} + removedpending_auth_sessions map[int64]struct{} + clearedpending_auth_sessions bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -28988,6 +33258,140 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetSignupSource sets the "signup_source" field. +func (m *UserMutation) SetSignupSource(s string) { + m.signup_source = &s +} + +// SignupSource returns the value of the "signup_source" field in the mutation. +func (m *UserMutation) SignupSource() (r string, exists bool) { + v := m.signup_source + if v == nil { + return + } + return *v, true +} + +// OldSignupSource returns the old "signup_source" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSignupSource is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSignupSource requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSignupSource: %w", err) + } + return oldValue.SignupSource, nil +} + +// ResetSignupSource resets all changes to the "signup_source" field. +func (m *UserMutation) ResetSignupSource() { + m.signup_source = nil +} + +// SetLastLoginAt sets the "last_login_at" field. +func (m *UserMutation) SetLastLoginAt(t time.Time) { + m.last_login_at = &t +} + +// LastLoginAt returns the value of the "last_login_at" field in the mutation. +func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) { + v := m.last_login_at + if v == nil { + return + } + return *v, true +} + +// OldLastLoginAt returns the old "last_login_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastLoginAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err) + } + return oldValue.LastLoginAt, nil +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (m *UserMutation) ClearLastLoginAt() { + m.last_login_at = nil + m.clearedFields[user.FieldLastLoginAt] = struct{}{} +} + +// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation. +func (m *UserMutation) LastLoginAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastLoginAt] + return ok +} + +// ResetLastLoginAt resets all changes to the "last_login_at" field. +func (m *UserMutation) ResetLastLoginAt() { + m.last_login_at = nil + delete(m.clearedFields, user.FieldLastLoginAt) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (m *UserMutation) SetLastActiveAt(t time.Time) { + m.last_active_at = &t +} + +// LastActiveAt returns the value of the "last_active_at" field in the mutation. +func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) { + v := m.last_active_at + if v == nil { + return + } + return *v, true +} + +// OldLastActiveAt returns the old "last_active_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastActiveAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err) + } + return oldValue.LastActiveAt, nil +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (m *UserMutation) ClearLastActiveAt() { + m.last_active_at = nil + m.clearedFields[user.FieldLastActiveAt] = struct{}{} +} + +// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation. +func (m *UserMutation) LastActiveAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastActiveAt] + return ok +} + +// ResetLastActiveAt resets all changes to the "last_active_at" field. +func (m *UserMutation) ResetLastActiveAt() { + m.last_active_at = nil + delete(m.clearedFields, user.FieldLastActiveAt) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (m *UserMutation) SetBalanceNotifyEnabled(b bool) { m.balance_notify_enabled = &b @@ -29762,6 +34166,114 @@ func (m *UserMutation) ResetPaymentOrders() { m.removedpayment_orders = nil } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids. +func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) { + if m.auth_identities == nil { + m.auth_identities = make(map[int64]struct{}) + } + for i := range ids { + m.auth_identities[ids[i]] = struct{}{} + } +} + +// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) ClearAuthIdentities() { + m.clearedauth_identities = true +} + +// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared. +func (m *UserMutation) AuthIdentitiesCleared() bool { + return m.clearedauth_identities +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs. +func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) { + if m.removedauth_identities == nil { + m.removedauth_identities = make(map[int64]struct{}) + } + for i := range ids { + delete(m.auth_identities, ids[i]) + m.removedauth_identities[ids[i]] = struct{}{} + } +} + +// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) { + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return +} + +// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation. +func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) { + for id := range m.auth_identities { + ids = append(ids, id) + } + return +} + +// ResetAuthIdentities resets all changes to the "auth_identities" edge. +func (m *UserMutation) ResetAuthIdentities() { + m.auth_identities = nil + m.clearedauth_identities = false + m.removedauth_identities = nil +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids. +func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) { + if m.pending_auth_sessions == nil { + m.pending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + m.pending_auth_sessions[ids[i]] = struct{}{} + } +} + +// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) ClearPendingAuthSessions() { + m.clearedpending_auth_sessions = true +} + +// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared. +func (m *UserMutation) PendingAuthSessionsCleared() bool { + return m.clearedpending_auth_sessions +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) { + if m.removedpending_auth_sessions == nil { + m.removedpending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.pending_auth_sessions, ids[i]) + m.removedpending_auth_sessions[ids[i]] = struct{}{} + } +} + +// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) { + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return +} + +// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation. +func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) { + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return +} + +// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge. +func (m *UserMutation) ResetPendingAuthSessions() { + m.pending_auth_sessions = nil + m.clearedpending_auth_sessions = false + m.removedpending_auth_sessions = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -29796,7 +34308,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 19) + fields := make([]string, 0, 22) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -29839,6 +34351,15 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.signup_source != nil { + fields = append(fields, user.FieldSignupSource) + } + if m.last_login_at != nil { + fields = append(fields, user.FieldLastLoginAt) + } + if m.last_active_at != nil { + fields = append(fields, user.FieldLastActiveAt) + } if m.balance_notify_enabled != nil { fields = append(fields, user.FieldBalanceNotifyEnabled) } @@ -29890,6 +34411,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldSignupSource: + return m.SignupSource() + case user.FieldLastLoginAt: + return m.LastLoginAt() + case user.FieldLastActiveAt: + return m.LastActiveAt() case user.FieldBalanceNotifyEnabled: return m.BalanceNotifyEnabled() case user.FieldBalanceNotifyThresholdType: @@ -29937,6 +34464,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldSignupSource: + return m.OldSignupSource(ctx) + case user.FieldLastLoginAt: + return m.OldLastLoginAt(ctx) + case user.FieldLastActiveAt: + return m.OldLastActiveAt(ctx) case user.FieldBalanceNotifyEnabled: return m.OldBalanceNotifyEnabled(ctx) case user.FieldBalanceNotifyThresholdType: @@ -30054,6 +34587,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldSignupSource: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSignupSource(v) + return nil + case user.FieldLastLoginAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastLoginAt(v) + return nil + case user.FieldLastActiveAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastActiveAt(v) + return nil case user.FieldBalanceNotifyEnabled: v, ok := value.(bool) if !ok { @@ -30179,6 +34733,12 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldTotpEnabledAt) { fields = append(fields, user.FieldTotpEnabledAt) } + if m.FieldCleared(user.FieldLastLoginAt) { + fields = append(fields, user.FieldLastLoginAt) + } + if m.FieldCleared(user.FieldLastActiveAt) { + fields = append(fields, user.FieldLastActiveAt) + } if m.FieldCleared(user.FieldBalanceNotifyThreshold) { fields = append(fields, user.FieldBalanceNotifyThreshold) } @@ -30205,6 +34765,12 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldTotpEnabledAt: m.ClearTotpEnabledAt() return nil + case user.FieldLastLoginAt: + m.ClearLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ClearLastActiveAt() + return nil case user.FieldBalanceNotifyThreshold: m.ClearBalanceNotifyThreshold() return nil @@ -30258,6 +34824,15 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldSignupSource: + m.ResetSignupSource() + return nil + case user.FieldLastLoginAt: + m.ResetLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ResetLastActiveAt() + return nil case user.FieldBalanceNotifyEnabled: m.ResetBalanceNotifyEnabled() return nil @@ -30279,7 +34854,7 @@ func (m *UserMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.api_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30310,6 +34885,12 @@ func (m *UserMutation) AddedEdges() []string { if m.payment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.auth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.pending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30377,13 +34958,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.auth_identities)) + for id := range m.auth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.pending_auth_sessions)) + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.removedapi_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30414,6 +35007,12 @@ func (m *UserMutation) RemovedEdges() []string { if m.removedpayment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.removedauth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.removedpending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30481,13 +35080,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.removedauth_identities)) + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions)) + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.clearedapi_keys { edges = append(edges, user.EdgeAPIKeys) } @@ -30518,6 +35129,12 @@ func (m *UserMutation) ClearedEdges() []string { if m.clearedpayment_orders { edges = append(edges, user.EdgePaymentOrders) } + if m.clearedauth_identities { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.clearedpending_auth_sessions { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30545,6 +35162,10 @@ func (m *UserMutation) EdgeCleared(name string) bool { return m.clearedpromo_code_usages case user.EdgePaymentOrders: return m.clearedpayment_orders + case user.EdgeAuthIdentities: + return m.clearedauth_identities + case user.EdgePendingAuthSessions: + return m.clearedpending_auth_sessions } return false } @@ -30591,6 +35212,12 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgePaymentOrders: m.ResetPaymentOrders() return nil + case user.EdgeAuthIdentities: + m.ResetAuthIdentities() + return nil + case user.EdgePendingAuthSessions: + m.ResetPendingAuthSessions() + return nil } return fmt.Errorf("unknown User edge %s", name) } diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go new file mode 100644 index 00000000..e77c065f --- /dev/null +++ b/backend/ent/pendingauthsession.go @@ -0,0 +1,399 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSession is the model entity for the PendingAuthSession schema. +type PendingAuthSession struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // SessionToken holds the value of the "session_token" field. + SessionToken string `json:"session_token,omitempty"` + // Intent holds the value of the "intent" field. + Intent string `json:"intent,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // TargetUserID holds the value of the "target_user_id" field. + TargetUserID *int64 `json:"target_user_id,omitempty"` + // RedirectTo holds the value of the "redirect_to" field. + RedirectTo string `json:"redirect_to,omitempty"` + // ResolvedEmail holds the value of the "resolved_email" field. + ResolvedEmail string `json:"resolved_email,omitempty"` + // RegistrationPasswordHash holds the value of the "registration_password_hash" field. + RegistrationPasswordHash string `json:"registration_password_hash,omitempty"` + // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field. + UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"` + // LocalFlowState holds the value of the "local_flow_state" field. + LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"` + // BrowserSessionKey holds the value of the "browser_session_key" field. + BrowserSessionKey string `json:"browser_session_key,omitempty"` + // CompletionCodeHash holds the value of the "completion_code_hash" field. + CompletionCodeHash string `json:"completion_code_hash,omitempty"` + // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field. + CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"` + // EmailVerifiedAt holds the value of the "email_verified_at" field. + EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"` + // PasswordVerifiedAt holds the value of the "password_verified_at" field. + PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"` + // TotpVerifiedAt holds the value of the "totp_verified_at" field. + TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + // ConsumedAt holds the value of the "consumed_at" field. + ConsumedAt *time.Time `json:"consumed_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the PendingAuthSessionQuery when eager-loading is set. + Edges PendingAuthSessionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph. +type PendingAuthSessionEdges struct { + // TargetUser holds the value of the target_user edge. + TargetUser *User `json:"target_user,omitempty"` + // AdoptionDecision holds the value of the adoption_decision edge. + AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// TargetUserOrErr returns the TargetUser value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) { + if e.TargetUser != nil { + return e.TargetUser, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "target_user"} +} + +// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) { + if e.AdoptionDecision != nil { + return e.AdoptionDecision, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: identityadoptiondecision.Label} + } + return nil, &NotLoadedError{edge: "adoption_decision"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*PendingAuthSession) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState: + values[i] = new([]byte) + case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID: + values[i] = new(sql.NullInt64) + case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash: + values[i] = new(sql.NullString) + case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the PendingAuthSession fields. +func (_m *PendingAuthSession) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case pendingauthsession.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case pendingauthsession.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case pendingauthsession.FieldSessionToken: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field session_token", values[i]) + } else if value.Valid { + _m.SessionToken = value.String + } + case pendingauthsession.FieldIntent: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field intent", values[i]) + } else if value.Valid { + _m.Intent = value.String + } + case pendingauthsession.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case pendingauthsession.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case pendingauthsession.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) + } else if value.Valid { + _m.ProviderSubject = value.String + } + case pendingauthsession.FieldTargetUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field target_user_id", values[i]) + } else if value.Valid { + _m.TargetUserID = new(int64) + *_m.TargetUserID = value.Int64 + } + case pendingauthsession.FieldRedirectTo: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field redirect_to", values[i]) + } else if value.Valid { + _m.RedirectTo = value.String + } + case pendingauthsession.FieldResolvedEmail: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field resolved_email", values[i]) + } else if value.Valid { + _m.ResolvedEmail = value.String + } + case pendingauthsession.FieldRegistrationPasswordHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i]) + } else if value.Valid { + _m.RegistrationPasswordHash = value.String + } + case pendingauthsession.FieldUpstreamIdentityClaims: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil { + return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err) + } + } + case pendingauthsession.FieldLocalFlowState: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field local_flow_state", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil { + return fmt.Errorf("unmarshal field local_flow_state: %w", err) + } + } + case pendingauthsession.FieldBrowserSessionKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field browser_session_key", values[i]) + } else if value.Valid { + _m.BrowserSessionKey = value.String + } + case pendingauthsession.FieldCompletionCodeHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i]) + } else if value.Valid { + _m.CompletionCodeHash = value.String + } + case pendingauthsession.FieldCompletionCodeExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i]) + } else if value.Valid { + _m.CompletionCodeExpiresAt = new(time.Time) + *_m.CompletionCodeExpiresAt = value.Time + } + case pendingauthsession.FieldEmailVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field email_verified_at", values[i]) + } else if value.Valid { + _m.EmailVerifiedAt = new(time.Time) + *_m.EmailVerifiedAt = value.Time + } + case pendingauthsession.FieldPasswordVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field password_verified_at", values[i]) + } else if value.Valid { + _m.PasswordVerifiedAt = new(time.Time) + *_m.PasswordVerifiedAt = value.Time + } + case pendingauthsession.FieldTotpVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i]) + } else if value.Valid { + _m.TotpVerifiedAt = new(time.Time) + *_m.TotpVerifiedAt = value.Time + } + case pendingauthsession.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = value.Time + } + case pendingauthsession.FieldConsumedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field consumed_at", values[i]) + } else if value.Valid { + _m.ConsumedAt = new(time.Time) + *_m.ConsumedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession. +// This includes values selected through modifiers, order, etc. +func (_m *PendingAuthSession) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryTargetUser() *UserQuery { + return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m) +} + +// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m) +} + +// Update returns a builder for updating this PendingAuthSession. +// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne { + return NewPendingAuthSessionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *PendingAuthSession) Unwrap() *PendingAuthSession { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: PendingAuthSession is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *PendingAuthSession) String() string { + var builder strings.Builder + builder.WriteString("PendingAuthSession(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("session_token=") + builder.WriteString(_m.SessionToken) + builder.WriteString(", ") + builder.WriteString("intent=") + builder.WriteString(_m.Intent) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.TargetUserID; v != nil { + builder.WriteString("target_user_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("redirect_to=") + builder.WriteString(_m.RedirectTo) + builder.WriteString(", ") + builder.WriteString("resolved_email=") + builder.WriteString(_m.ResolvedEmail) + builder.WriteString(", ") + builder.WriteString("registration_password_hash=") + builder.WriteString(_m.RegistrationPasswordHash) + builder.WriteString(", ") + builder.WriteString("upstream_identity_claims=") + builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims)) + builder.WriteString(", ") + builder.WriteString("local_flow_state=") + builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState)) + builder.WriteString(", ") + builder.WriteString("browser_session_key=") + builder.WriteString(_m.BrowserSessionKey) + builder.WriteString(", ") + builder.WriteString("completion_code_hash=") + builder.WriteString(_m.CompletionCodeHash) + builder.WriteString(", ") + if v := _m.CompletionCodeExpiresAt; v != nil { + builder.WriteString("completion_code_expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.EmailVerifiedAt; v != nil { + builder.WriteString("email_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.PasswordVerifiedAt; v != nil { + builder.WriteString("password_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.TotpVerifiedAt; v != nil { + builder.WriteString("totp_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(_m.ExpiresAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.ConsumedAt; v != nil { + builder.WriteString("consumed_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// PendingAuthSessions is a parsable slice of PendingAuthSession. +type PendingAuthSessions []*PendingAuthSession diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go new file mode 100644 index 00000000..8a3ac9bf --- /dev/null +++ b/backend/ent/pendingauthsession/pendingauthsession.go @@ -0,0 +1,279 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the pendingauthsession type in the database. + Label = "pending_auth_session" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldSessionToken holds the string denoting the session_token field in the database. + FieldSessionToken = "session_token" + // FieldIntent holds the string denoting the intent field in the database. + FieldIntent = "intent" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldTargetUserID holds the string denoting the target_user_id field in the database. + FieldTargetUserID = "target_user_id" + // FieldRedirectTo holds the string denoting the redirect_to field in the database. + FieldRedirectTo = "redirect_to" + // FieldResolvedEmail holds the string denoting the resolved_email field in the database. + FieldResolvedEmail = "resolved_email" + // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database. + FieldRegistrationPasswordHash = "registration_password_hash" + // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database. + FieldUpstreamIdentityClaims = "upstream_identity_claims" + // FieldLocalFlowState holds the string denoting the local_flow_state field in the database. + FieldLocalFlowState = "local_flow_state" + // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database. + FieldBrowserSessionKey = "browser_session_key" + // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database. + FieldCompletionCodeHash = "completion_code_hash" + // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database. + FieldCompletionCodeExpiresAt = "completion_code_expires_at" + // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database. + FieldEmailVerifiedAt = "email_verified_at" + // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database. + FieldPasswordVerifiedAt = "password_verified_at" + // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database. + FieldTotpVerifiedAt = "totp_verified_at" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // FieldConsumedAt holds the string denoting the consumed_at field in the database. + FieldConsumedAt = "consumed_at" + // EdgeTargetUser holds the string denoting the target_user edge name in mutations. + EdgeTargetUser = "target_user" + // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations. + EdgeAdoptionDecision = "adoption_decision" + // Table holds the table name of the pendingauthsession in the database. + Table = "pending_auth_sessions" + // TargetUserTable is the table that holds the target_user relation/edge. + TargetUserTable = "pending_auth_sessions" + // TargetUserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + TargetUserInverseTable = "users" + // TargetUserColumn is the table column denoting the target_user relation/edge. + TargetUserColumn = "target_user_id" + // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge. + AdoptionDecisionTable = "identity_adoption_decisions" + // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. + AdoptionDecisionInverseTable = "identity_adoption_decisions" + // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge. + AdoptionDecisionColumn = "pending_auth_session_id" +) + +// Columns holds all SQL columns for pendingauthsession fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldSessionToken, + FieldIntent, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldTargetUserID, + FieldRedirectTo, + FieldResolvedEmail, + FieldRegistrationPasswordHash, + FieldUpstreamIdentityClaims, + FieldLocalFlowState, + FieldBrowserSessionKey, + FieldCompletionCodeHash, + FieldCompletionCodeExpiresAt, + FieldEmailVerifiedAt, + FieldPasswordVerifiedAt, + FieldTotpVerifiedAt, + FieldExpiresAt, + FieldConsumedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + SessionTokenValidator func(string) error + // IntentValidator is a validator for the "intent" field. It is called by the builders before save. + IntentValidator func(string) error + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultRedirectTo holds the default value on creation for the "redirect_to" field. + DefaultRedirectTo string + // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field. + DefaultResolvedEmail string + // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field. + DefaultRegistrationPasswordHash string + // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field. + DefaultUpstreamIdentityClaims func() map[string]interface{} + // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field. + DefaultLocalFlowState func() map[string]interface{} + // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field. + DefaultBrowserSessionKey string + // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field. + DefaultCompletionCodeHash string +) + +// OrderOption defines the ordering options for the PendingAuthSession queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// BySessionToken orders the results by the session_token field. +func BySessionToken(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSessionToken, opts...).ToFunc() +} + +// ByIntent orders the results by the intent field. +func ByIntent(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIntent, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByTargetUserID orders the results by the target_user_id field. +func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTargetUserID, opts...).ToFunc() +} + +// ByRedirectTo orders the results by the redirect_to field. +func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRedirectTo, opts...).ToFunc() +} + +// ByResolvedEmail orders the results by the resolved_email field. +func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc() +} + +// ByRegistrationPasswordHash orders the results by the registration_password_hash field. +func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc() +} + +// ByBrowserSessionKey orders the results by the browser_session_key field. +func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc() +} + +// ByCompletionCodeHash orders the results by the completion_code_hash field. +func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc() +} + +// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field. +func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc() +} + +// ByEmailVerifiedAt orders the results by the email_verified_at field. +func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc() +} + +// ByPasswordVerifiedAt orders the results by the password_verified_at field. +func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc() +} + +// ByTotpVerifiedAt orders the results by the totp_verified_at field. +func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + +// ByConsumedAt orders the results by the consumed_at field. +func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldConsumedAt, opts...).ToFunc() +} + +// ByTargetUserField orders the results by target_user field. +func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByAdoptionDecisionField orders the results by adoption_decision field. +func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...)) + } +} +func newTargetUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(TargetUserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) +} +func newAdoptionDecisionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) +} diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go new file mode 100644 index 00000000..cb316f47 --- /dev/null +++ b/backend/ent/pendingauthsession/where.go @@ -0,0 +1,1262 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ. +func SessionToken(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ. +func Intent(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ. +func TargetUserID(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ. +func RedirectTo(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ. +func ResolvedEmail(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ. +func RegistrationPasswordHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ. +func BrowserSessionKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ. +func CompletionCodeHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ. +func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ. +func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ. +func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ. +func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ. +func ConsumedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// SessionTokenEQ applies the EQ predicate on the "session_token" field. +func SessionTokenEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// SessionTokenNEQ applies the NEQ predicate on the "session_token" field. +func SessionTokenNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v)) +} + +// SessionTokenIn applies the In predicate on the "session_token" field. +func SessionTokenIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...)) +} + +// SessionTokenNotIn applies the NotIn predicate on the "session_token" field. +func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...)) +} + +// SessionTokenGT applies the GT predicate on the "session_token" field. +func SessionTokenGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v)) +} + +// SessionTokenGTE applies the GTE predicate on the "session_token" field. +func SessionTokenGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v)) +} + +// SessionTokenLT applies the LT predicate on the "session_token" field. +func SessionTokenLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v)) +} + +// SessionTokenLTE applies the LTE predicate on the "session_token" field. +func SessionTokenLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v)) +} + +// SessionTokenContains applies the Contains predicate on the "session_token" field. +func SessionTokenContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v)) +} + +// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field. +func SessionTokenHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v)) +} + +// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field. +func SessionTokenHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v)) +} + +// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field. +func SessionTokenEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v)) +} + +// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field. +func SessionTokenContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v)) +} + +// IntentEQ applies the EQ predicate on the "intent" field. +func IntentEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// IntentNEQ applies the NEQ predicate on the "intent" field. +func IntentNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v)) +} + +// IntentIn applies the In predicate on the "intent" field. +func IntentIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...)) +} + +// IntentNotIn applies the NotIn predicate on the "intent" field. +func IntentNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...)) +} + +// IntentGT applies the GT predicate on the "intent" field. +func IntentGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v)) +} + +// IntentGTE applies the GTE predicate on the "intent" field. +func IntentGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v)) +} + +// IntentLT applies the LT predicate on the "intent" field. +func IntentLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v)) +} + +// IntentLTE applies the LTE predicate on the "intent" field. +func IntentLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v)) +} + +// IntentContains applies the Contains predicate on the "intent" field. +func IntentContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v)) +} + +// IntentHasPrefix applies the HasPrefix predicate on the "intent" field. +func IntentHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v)) +} + +// IntentHasSuffix applies the HasSuffix predicate on the "intent" field. +func IntentHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v)) +} + +// IntentEqualFold applies the EqualFold predicate on the "intent" field. +func IntentEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v)) +} + +// IntentContainsFold applies the ContainsFold predicate on the "intent" field. +func IntentContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. +func ProviderSubjectLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field. +func TargetUserIDEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field. +func TargetUserIDNEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v)) +} + +// TargetUserIDIn applies the In predicate on the "target_user_id" field. +func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field. +func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field. +func TargetUserIDIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID)) +} + +// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field. +func TargetUserIDNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID)) +} + +// RedirectToEQ applies the EQ predicate on the "redirect_to" field. +func RedirectToEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field. +func RedirectToNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v)) +} + +// RedirectToIn applies the In predicate on the "redirect_to" field. +func RedirectToIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...)) +} + +// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field. +func RedirectToNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...)) +} + +// RedirectToGT applies the GT predicate on the "redirect_to" field. +func RedirectToGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v)) +} + +// RedirectToGTE applies the GTE predicate on the "redirect_to" field. +func RedirectToGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v)) +} + +// RedirectToLT applies the LT predicate on the "redirect_to" field. +func RedirectToLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v)) +} + +// RedirectToLTE applies the LTE predicate on the "redirect_to" field. +func RedirectToLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v)) +} + +// RedirectToContains applies the Contains predicate on the "redirect_to" field. +func RedirectToContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v)) +} + +// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field. +func RedirectToHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v)) +} + +// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field. +func RedirectToHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v)) +} + +// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field. +func RedirectToEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v)) +} + +// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field. +func RedirectToContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v)) +} + +// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field. +func ResolvedEmailEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field. +func ResolvedEmailNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailIn applies the In predicate on the "resolved_email" field. +func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field. +func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailGT applies the GT predicate on the "resolved_email" field. +func ResolvedEmailGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v)) +} + +// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field. +func ResolvedEmailGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailLT applies the LT predicate on the "resolved_email" field. +func ResolvedEmailLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v)) +} + +// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field. +func ResolvedEmailLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field. +func ResolvedEmailContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field. +func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field. +func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v)) +} + +// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field. +func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v)) +} + +// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field. +func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field. +func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field. +func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field. +func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field. +func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field. +func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field. +func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field. +func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field. +func BrowserSessionKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field. +func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field. +func BrowserSessionKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field. +func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field. +func BrowserSessionKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field. +func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field. +func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field. +func CompletionCodeHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field. +func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field. +func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field. +func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field. +func CompletionCodeHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field. +func CompletionCodeHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field. +func CompletionCodeHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field. +func CompletionCodeHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field. +func CompletionCodeHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field. +func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field. +func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt)) +} + +// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt)) +} + +// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field. +func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field. +func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field. +func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field. +func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field. +func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field. +func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field. +func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field. +func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field. +func EmailVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt)) +} + +// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field. +func EmailVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt)) +} + +// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field. +func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field. +func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field. +func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field. +func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt)) +} + +// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt)) +} + +// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field. +func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field. +func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field. +func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field. +func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field. +func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field. +func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field. +func TotpVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt)) +} + +// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field. +func TotpVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field. +func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field. +func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v)) +} + +// ConsumedAtIn applies the In predicate on the "consumed_at" field. +func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field. +func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtGT applies the GT predicate on the "consumed_at" field. +func ConsumedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v)) +} + +// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field. +func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v)) +} + +// ConsumedAtLT applies the LT predicate on the "consumed_at" field. +func ConsumedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v)) +} + +// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field. +func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v)) +} + +// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field. +func ConsumedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt)) +} + +// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field. +func ConsumedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt)) +} + +// HasTargetUser applies the HasEdge predicate on the "target_user" edge. +func HasTargetUser() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates). +func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newTargetUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge. +func HasAdoptionDecision() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates). +func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newAdoptionDecisionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.NotPredicates(p)) +} diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go new file mode 100644 index 00000000..60276daa --- /dev/null +++ b/backend/ent/pendingauthsession_create.go @@ -0,0 +1,1815 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity. +type PendingAuthSessionCreate struct { + config + mutation *PendingAuthSessionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetSessionToken sets the "session_token" field. +func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate { + _c.mutation.SetSessionToken(v) + return _c +} + +// SetIntent sets the "intent" field. +func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate { + _c.mutation.SetIntent(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetTargetUserID sets the "target_user_id" field. +func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate { + _c.mutation.SetTargetUserID(v) + return _c +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate { + if v != nil { + _c.SetTargetUserID(*v) + } + return _c +} + +// SetRedirectTo sets the "redirect_to" field. +func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate { + _c.mutation.SetRedirectTo(v) + return _c +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRedirectTo(*v) + } + return _c +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate { + _c.mutation.SetResolvedEmail(v) + return _c +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetResolvedEmail(*v) + } + return _c +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetRegistrationPasswordHash(v) + return _c +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRegistrationPasswordHash(*v) + } + return _c +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetUpstreamIdentityClaims(v) + return _c +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetLocalFlowState(v) + return _c +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetBrowserSessionKey(v) + return _c +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetBrowserSessionKey(*v) + } + return _c +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeHash(v) + return _c +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeHash(*v) + } + return _c +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeExpiresAt(v) + return _c +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeExpiresAt(*v) + } + return _c +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetEmailVerifiedAt(v) + return _c +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetEmailVerifiedAt(*v) + } + return _c +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetPasswordVerifiedAt(v) + return _c +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetPasswordVerifiedAt(*v) + } + return _c +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetTotpVerifiedAt(v) + return _c +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetTotpVerifiedAt(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetConsumedAt sets the "consumed_at" field. +func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetConsumedAt(v) + return _c +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetConsumedAt(*v) + } + return _c +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate { + return _c.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate { + _c.mutation.SetAdoptionDecisionID(id) + return _c +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate { + if id != nil { + _c = _c.SetAdoptionDecisionID(*id) + } + return _c +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate { + return _c.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation { + return _c.mutation +} + +// Save creates the PendingAuthSession in the database. +func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *PendingAuthSessionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := pendingauthsession.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := pendingauthsession.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.RedirectTo(); !ok { + v := pendingauthsession.DefaultRedirectTo + _c.mutation.SetRedirectTo(v) + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + v := pendingauthsession.DefaultResolvedEmail + _c.mutation.SetResolvedEmail(v) + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + v := pendingauthsession.DefaultRegistrationPasswordHash + _c.mutation.SetRegistrationPasswordHash(v) + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + v := pendingauthsession.DefaultUpstreamIdentityClaims() + _c.mutation.SetUpstreamIdentityClaims(v) + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + v := pendingauthsession.DefaultLocalFlowState() + _c.mutation.SetLocalFlowState(v) + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + v := pendingauthsession.DefaultBrowserSessionKey + _c.mutation.SetBrowserSessionKey(v) + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + v := pendingauthsession.DefaultCompletionCodeHash + _c.mutation.SetCompletionCodeHash(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *PendingAuthSessionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)} + } + if _, ok := _c.mutation.SessionToken(); !ok { + return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)} + } + if v, ok := _c.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if _, ok := _c.mutation.Intent(); !ok { + return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)} + } + if v, ok := _c.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.RedirectTo(); !ok { + return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)} + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)} + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)} + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)} + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)} + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)} + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + return &ValidationError{Name: "completion_code_hash", err: errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)} + } + if _, ok := _c.mutation.ExpiresAt(); !ok { + return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)} + } + return nil +} + +func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) { + var ( + _node = &PendingAuthSession{config: _c.config} + _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + _node.SessionToken = value + } + if value, ok := _c.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + _node.Intent = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + _node.RedirectTo = value + } + if value, ok := _c.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + _node.ResolvedEmail = value + } + if value, ok := _c.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + _node.RegistrationPasswordHash = value + } + if value, ok := _c.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + _node.UpstreamIdentityClaims = value + } + if value, ok := _c.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + _node.LocalFlowState = value + } + if value, ok := _c.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + _node.BrowserSessionKey = value + } + if value, ok := _c.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + _node.CompletionCodeHash = value + } + if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + _node.CompletionCodeExpiresAt = &value + } + if value, ok := _c.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + _node.EmailVerifiedAt = &value + } + if value, ok := _c.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + _node.PasswordVerifiedAt = &value + } + if value, ok := _c.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + _node.TotpVerifiedAt = &value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = value + } + if value, ok := _c.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + _node.ConsumedAt = &value + } + if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.TargetUserID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne { + _c.conflict = opts + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +type ( + // PendingAuthSessionUpsertOne is the builder for "upsert"-ing + // one PendingAuthSession node. + PendingAuthSessionUpsertOne struct { + create *PendingAuthSessionCreate + } + + // PendingAuthSessionUpsert is the "OnConflict" setter. + PendingAuthSessionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpdatedAt) + return u +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldSessionToken, v) + return u +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldSessionToken) + return u +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldIntent, v) + return u +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldIntent) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderSubject) + return u +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTargetUserID, v) + return u +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTargetUserID) + return u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTargetUserID) + return u +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRedirectTo, v) + return u +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRedirectTo) + return u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldResolvedEmail, v) + return u +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldResolvedEmail) + return u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRegistrationPasswordHash, v) + return u +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash) + return u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v) + return u +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims) + return u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldLocalFlowState, v) + return u +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldLocalFlowState) + return u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldBrowserSessionKey, v) + return u +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldBrowserSessionKey) + return u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeHash, v) + return u +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeHash) + return u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v) + return u +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldEmailVerifiedAt, v) + return u +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldPasswordVerifiedAt, v) + return u +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTotpVerifiedAt, v) + return u +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldExpiresAt) + return u +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldConsumedAt, v) + return u +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldConsumedAt) + return u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldConsumedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. +func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk. +type PendingAuthSessionCreateBulk struct { + config + err error + builders []*PendingAuthSessionCreate + conflict []sql.ConflictOption +} + +// Save creates the PendingAuthSession entities in the database. +func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*PendingAuthSession, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*PendingAuthSessionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk { + _c.conflict = opts + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing +// a bulk of PendingAuthSession nodes. +type PendingAuthSessionUpsertBulk struct { + create *PendingAuthSessionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. +func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go new file mode 100644 index 00000000..ee4fe605 --- /dev/null +++ b/backend/ent/pendingauthsession_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity. +type PendingAuthSessionDelete struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. +func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity. +type PendingAuthSessionDeleteOne struct { + _d *PendingAuthSessionDelete +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. +func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{pendingauthsession.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go new file mode 100644 index 00000000..78e29cd2 --- /dev/null +++ b/backend/ent/pendingauthsession_query.go @@ -0,0 +1,717 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities. +type PendingAuthSessionQuery struct { + config + ctx *QueryContext + order []pendingauthsession.OrderOption + inters []Interceptor + predicates []predicate.PendingAuthSession + withTargetUser *UserQuery + withAdoptionDecision *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the PendingAuthSessionQuery builder. +func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryTargetUser chains the current query on the "target_user" edge. +func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecision chains the current query on the "adoption_decision" edge. +func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first PendingAuthSession entity from the query. +// Returns a *NotFoundError when no PendingAuthSession was found. +func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{pendingauthsession.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first PendingAuthSession ID from the query. +// Returns a *NotFoundError when no PendingAuthSession ID was found. +func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{pendingauthsession.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one PendingAuthSession entity is found. +// Returns a *NotFoundError when no PendingAuthSession entities are found. +func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{pendingauthsession.Label} + default: + return nil, &NotSingularError{pendingauthsession.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only PendingAuthSession ID in the query. +// Returns a *NotSingularError when more than one PendingAuthSession ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{pendingauthsession.Label} + default: + err = &NotSingularError{pendingauthsession.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of PendingAuthSessions. +func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]() + return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of PendingAuthSession IDs. +func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery { + if _q == nil { + return nil + } + return &PendingAuthSessionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]pendingauthsession.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.PendingAuthSession{}, _q.predicates...), + withTargetUser: _q.withTargetUser.Clone(), + withAdoptionDecision: _q.withAdoptionDecision.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithTargetUser tells the query-builder to eager-load the nodes that are connected to +// the "target_user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withTargetUser = query + return _q +} + +// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecision = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// GroupBy(pendingauthsession.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &PendingAuthSessionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = pendingauthsession.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// Select(pendingauthsession.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q} + sbuild.label = pendingauthsession.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations. +func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !pendingauthsession.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) { + var ( + nodes = []*PendingAuthSession{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withTargetUser != nil, + _q.withAdoptionDecision != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*PendingAuthSession).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &PendingAuthSession{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withTargetUser; query != nil { + if err := _q.loadTargetUser(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecision; query != nil { + if err := _q.loadAdoptionDecision(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*PendingAuthSession) + for i := range nodes { + if nodes[i].TargetUserID == nil { + continue + } + fk := *nodes[i].TargetUserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*PendingAuthSession) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.PendingAuthSessionID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for i := range fields { + if fields[i] != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withTargetUser != nil { + _spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(pendingauthsession.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = pendingauthsession.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities. +type PendingAuthSessionGroupBy struct { + selector + build *PendingAuthSessionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities. +type PendingAuthSessionSelect struct { + *PendingAuthSessionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v) +} + +func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go new file mode 100644 index 00000000..00066f69 --- /dev/null +++ b/backend/ent/pendingauthsession_update.go @@ -0,0 +1,1178 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities. +type PendingAuthSessionUpdate struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. +func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdate) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + return nil +} + +func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// PendingAuthSessionUpdateOne is the builder for updating a single PendingAuthSession entity. +type PendingAuthSessionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. +func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated PendingAuthSession entity. +func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdateOne) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + return nil +} + +func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for _, f := range fields { + if !pendingauthsession.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &PendingAuthSession{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index ef551940..0aa90b90 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -21,6 +21,12 @@ type Announcement func(*sql.Selector) // AnnouncementRead is the predicate function for announcementread builders. type AnnouncementRead func(*sql.Selector) +// AuthIdentity is the predicate function for authidentity builders. +type AuthIdentity func(*sql.Selector) + +// AuthIdentityChannel is the predicate function for authidentitychannel builders. +type AuthIdentityChannel func(*sql.Selector) + // ErrorPassthroughRule is the predicate function for errorpassthroughrule builders. type ErrorPassthroughRule func(*sql.Selector) @@ -30,6 +36,9 @@ type Group func(*sql.Selector) // IdempotencyRecord is the predicate function for idempotencyrecord builders. type IdempotencyRecord func(*sql.Selector) +// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders. +type IdentityAdoptionDecision func(*sql.Selector) + // PaymentAuditLog is the predicate function for paymentauditlog builders. type PaymentAuditLog func(*sql.Selector) @@ -39,6 +48,9 @@ type PaymentOrder func(*sql.Selector) // PaymentProviderInstance is the predicate function for paymentproviderinstance builders. type PaymentProviderInstance func(*sql.Selector) +// PendingAuthSession is the predicate function for pendingauthsession builders. +type PendingAuthSession func(*sql.Selector) + // PromoCode is the predicate function for promocode builders. type PromoCode func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fbdd08c7..268e9ddb 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -10,12 +10,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -309,6 +313,120 @@ func init() { announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) + authidentityMixin := schema.AuthIdentity{}.Mixin() + authidentityMixinFields0 := authidentityMixin[0].Fields() + _ = authidentityMixinFields0 + authidentityFields := schema.AuthIdentity{}.Fields() + _ = authidentityFields + // authidentityDescCreatedAt is the schema descriptor for created_at field. + authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor() + // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time) + // authidentityDescUpdatedAt is the schema descriptor for updated_at field. + authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor() + // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time) + // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentityDescProviderType is the schema descriptor for provider_type field. + authidentityDescProviderType := authidentityFields[1].Descriptor() + // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentity.ProviderTypeValidator = func() func(string) error { + validators := authidentityDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentityDescProviderKey is the schema descriptor for provider_key field. + authidentityDescProviderKey := authidentityFields[2].Descriptor() + // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error) + // authidentityDescProviderSubject is the schema descriptor for provider_subject field. + authidentityDescProviderSubject := authidentityFields[3].Descriptor() + // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error) + // authidentityDescMetadata is the schema descriptor for metadata field. + authidentityDescMetadata := authidentityFields[6].Descriptor() + // authidentity.DefaultMetadata holds the default value on creation for the metadata field. + authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{}) + authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin() + authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields() + _ = authidentitychannelMixinFields0 + authidentitychannelFields := schema.AuthIdentityChannel{}.Fields() + _ = authidentitychannelFields + // authidentitychannelDescCreatedAt is the schema descriptor for created_at field. + authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor() + // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time) + // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field. + authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor() + // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time) + // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentitychannelDescProviderType is the schema descriptor for provider_type field. + authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor() + // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentitychannel.ProviderTypeValidator = func() func(string) error { + validators := authidentitychannelDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescProviderKey is the schema descriptor for provider_key field. + authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor() + // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error) + // authidentitychannelDescChannel is the schema descriptor for channel field. + authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor() + // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + authidentitychannel.ChannelValidator = func() func(string) error { + validators := authidentitychannelDescChannel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(channel string) error { + for _, fn := range fns { + if err := fn(channel); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field. + authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor() + // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error) + // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field. + authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor() + // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error) + // authidentitychannelDescMetadata is the schema descriptor for metadata field. + authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor() + // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field. + authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{}) errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin() errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields() _ = errorpassthroughruleMixinFields0 @@ -512,6 +630,33 @@ func init() { idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor() // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error) + identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin() + identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields() + _ = identityadoptiondecisionMixinFields0 + identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields() + _ = identityadoptiondecisionFields + // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field. + identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor() + // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field. + identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time) + // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field. + identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor() + // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field. + identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time) + // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time) + // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field. + identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor() + // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field. + identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool) + // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field. + identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor() + // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field. + identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool) + // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field. + identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor() + // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field. + identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time) paymentauditlogFields := schema.PaymentAuditLog{}.Fields() _ = paymentauditlogFields // paymentauditlogDescOrderID is the schema descriptor for order_id field. @@ -682,6 +827,113 @@ func init() { paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time) + pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin() + pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields() + _ = pendingauthsessionMixinFields0 + pendingauthsessionFields := schema.PendingAuthSession{}.Fields() + _ = pendingauthsessionFields + // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field. + pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor() + // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field. + pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time) + // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field. + pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor() + // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field. + pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time) + // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time) + // pendingauthsessionDescSessionToken is the schema descriptor for session_token field. + pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor() + // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + pendingauthsession.SessionTokenValidator = func() func(string) error { + validators := pendingauthsessionDescSessionToken.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(session_token string) error { + for _, fn := range fns { + if err := fn(session_token); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescIntent is the schema descriptor for intent field. + pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor() + // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save. + pendingauthsession.IntentValidator = func() func(string) error { + validators := pendingauthsessionDescIntent.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(intent string) error { + for _, fn := range fns { + if err := fn(intent); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderType is the schema descriptor for provider_type field. + pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor() + // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + pendingauthsession.ProviderTypeValidator = func() func(string) error { + validators := pendingauthsessionDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field. + pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor() + // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error) + // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field. + pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor() + // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error) + // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field. + pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor() + // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field. + pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string) + // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field. + pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor() + // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field. + pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string) + // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field. + pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor() + // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field. + pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string) + // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field. + pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor() + // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field. + pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{}) + // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field. + pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor() + // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field. + pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{}) + // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field. + pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor() + // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field. + pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string) + // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field. + pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor() + // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field. + pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. @@ -1297,20 +1549,26 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescSignupSource is the schema descriptor for signup_source field. + userDescSignupSource := userFields[11].Descriptor() + // user.DefaultSignupSource holds the default value on creation for the signup_source field. + user.DefaultSignupSource = userDescSignupSource.Default.(string) + // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save. + user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error) // userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field. - userDescBalanceNotifyEnabled := userFields[11].Descriptor() + userDescBalanceNotifyEnabled := userFields[14].Descriptor() // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field. user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool) // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field. - userDescBalanceNotifyThresholdType := userFields[12].Descriptor() + userDescBalanceNotifyThresholdType := userFields[15].Descriptor() // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field. user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string) // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field. - userDescBalanceNotifyExtraEmails := userFields[14].Descriptor() + userDescBalanceNotifyExtraEmails := userFields[17].Descriptor() // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field. user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string) // userDescTotalRecharged is the schema descriptor for total_recharged field. - userDescTotalRecharged := userFields[15].Descriptor() + userDescTotalRecharged := userFields[18].Descriptor() // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field. user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go new file mode 100644 index 00000000..e4b9ac90 --- /dev/null +++ b/backend/ent/schema/auth_identity.go @@ -0,0 +1,93 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var authProviderTypes = map[string]struct{}{ + "email": {}, + "linuxdo": {}, + "oidc": {}, + "wechat": {}, +} + +func validateAuthProviderType(value string) error { + if _, ok := authProviderTypes[value]; ok { + return nil + } + return fmt.Errorf("invalid auth provider type %q", value) +} + +// AuthIdentity stores the canonical login identity for an account. +type AuthIdentity struct { + ent.Schema +} + +func (AuthIdentity) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identities"}, + } +} + +func (AuthIdentity) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentity) Fields() []ent.Field { + return []ent.Field{ + field.Int64("user_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.String("issuer"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentity) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("auth_identities"). + Field("user_id"). + Required(). + Unique(), + edge.To("channels", AuthIdentityChannel.Type), + edge.To("adoption_decisions", IdentityAdoptionDecision.Type), + } +} + +func (AuthIdentity) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "provider_subject").Unique(), + index.Fields("user_id"), + index.Fields("user_id", "provider_type"), + } +} diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go new file mode 100644 index 00000000..69f2ad02 --- /dev/null +++ b/backend/ent/schema/auth_identity_channel.go @@ -0,0 +1,72 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity. +type AuthIdentityChannel struct { + ent.Schema +} + +func (AuthIdentityChannel) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identity_channels"}, + } +} + +func (AuthIdentityChannel) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentityChannel) Fields() []ent.Field { + return []ent.Field{ + field.Int64("identity_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel"). + MaxLen(20). + NotEmpty(), + field.String("channel_app_id"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentityChannel) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("identity", AuthIdentity.Type). + Ref("channels"). + Field("identity_id"). + Required(). + Unique(), + } +} + +func (AuthIdentityChannel) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go new file mode 100644 index 00000000..de55dd69 --- /dev/null +++ b/backend/ent/schema/auth_identity_schema_test.go @@ -0,0 +1,124 @@ +package schema + +import ( + "testing" + + "entgo.io/ent/entc/load" + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityFoundationSchemas(t *testing.T) { + spec, err := (&load.Config{Path: "."}).Load() + require.NoError(t, err) + + schemas := map[string]*load.Schema{} + for _, schema := range spec.Schemas { + schemas[schema.Name] = schema + } + + authIdentity := requireSchema(t, schemas, "AuthIdentity") + requireSchemaFields(t, authIdentity, + "user_id", + "provider_type", + "provider_key", + "provider_subject", + "verified_at", + "issuer", + "metadata", + ) + requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject") + + authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel") + requireSchemaFields(t, authIdentityChannel, + "identity_id", + "provider_type", + "provider_key", + "channel", + "channel_app_id", + "channel_subject", + "metadata", + ) + requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject") + + pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession") + requireSchemaFields(t, pendingAuthSession, + "intent", + "provider_type", + "provider_key", + "provider_subject", + "target_user_id", + "redirect_to", + "resolved_email", + "registration_password_hash", + "upstream_identity_claims", + "local_flow_state", + "browser_session_key", + "completion_code_hash", + "completion_code_expires_at", + "email_verified_at", + "password_verified_at", + "totp_verified_at", + "expires_at", + "consumed_at", + ) + + adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision") + requireSchemaFields(t, adoptionDecision, + "pending_auth_session_id", + "identity_id", + "adopt_display_name", + "adopt_avatar", + "decided_at", + ) + requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id") + + userSchema := requireSchema(t, schemas, "User") + requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at") +} + +func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema { + t.Helper() + + schema, ok := schemas[name] + require.True(t, ok, "schema %s should exist", name) + return schema +} + +func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) { + t.Helper() + + fields := map[string]struct{}{} + for _, field := range schema.Fields { + fields[field.Name] = struct{}{} + } + + for _, name := range names { + _, ok := fields[name] + require.True(t, ok, "schema %s should include field %s", schema.Name, name) + } +} + +func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) { + t.Helper() + + for _, index := range schema.Indexes { + if !index.Unique { + continue + } + if len(index.Fields) != len(fields) { + continue + } + match := true + for i := range fields { + if index.Fields[i] != fields[i] { + match = false + break + } + } + if match { + return + } + } + + require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields) +} diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go new file mode 100644 index 00000000..9fdd26fb --- /dev/null +++ b/backend/ent/schema/identity_adoption_decision.go @@ -0,0 +1,70 @@ +package schema + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow. +type IdentityAdoptionDecision struct { + ent.Schema +} + +func (IdentityAdoptionDecision) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "identity_adoption_decisions"}, + } +} + +func (IdentityAdoptionDecision) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (IdentityAdoptionDecision) Fields() []ent.Field { + return []ent.Field{ + field.Int64("pending_auth_session_id"), + field.Int64("identity_id"). + Optional(). + Nillable(), + field.Bool("adopt_display_name"). + Default(false), + field.Bool("adopt_avatar"). + Default(false), + field.Time("decided_at"). + Immutable(). + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (IdentityAdoptionDecision) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("pending_auth_session", PendingAuthSession.Type). + Ref("adoption_decision"). + Field("pending_auth_session_id"). + Required(). + Unique(), + edge.From("identity", AuthIdentity.Type). + Ref("adoption_decisions"). + Field("identity_id"). + Unique(), + } +} + +func (IdentityAdoptionDecision) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("pending_auth_session_id").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go new file mode 100644 index 00000000..91341d49 --- /dev/null +++ b/backend/ent/schema/pending_auth_session.go @@ -0,0 +1,134 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var pendingAuthIntents = map[string]struct{}{ + "login": {}, + "bind_current_user": {}, + "adopt_existing_user_by_email": {}, +} + +func validatePendingAuthIntent(value string) error { + if _, ok := pendingAuthIntents[value]; ok { + return nil + } + return fmt.Errorf("invalid pending auth intent %q", value) +} + +// PendingAuthSession stores a short-lived post-auth decision session. +type PendingAuthSession struct { + ent.Schema +} + +func (PendingAuthSession) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "pending_auth_sessions"}, + } +} + +func (PendingAuthSession) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (PendingAuthSession) Fields() []ent.Field { + return []ent.Field{ + field.String("session_token"). + MaxLen(255). + NotEmpty(), + field.String("intent"). + MaxLen(40). + NotEmpty(). + Validate(validatePendingAuthIntent), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Int64("target_user_id"). + Optional(). + Nillable(), + field.String("redirect_to"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("resolved_email"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("registration_password_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("upstream_identity_claims", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.JSON("local_flow_state", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.String("browser_session_key"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("completion_code_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("completion_code_expires_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("email_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("password_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("totp_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("expires_at"). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("consumed_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (PendingAuthSession) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("target_user", User.Type). + Ref("pending_auth_sessions"). + Field("target_user_id"). + Unique(), + edge.To("adoption_decision", IdentityAdoptionDecision.Type). + Unique(), + } +} + +func (PendingAuthSession) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("session_token").Unique(), + index.Fields("target_user_id"), + index.Fields("expires_at"), + index.Fields("provider_type", "provider_key", "provider_subject"), + index.Fields("completion_code_hash"), + } +} diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index ef52e985..bb58d9e3 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,6 +72,17 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), + field.String("signup_source"). + MaxLen(20). + Default("email"), + field.Time("last_login_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("last_active_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), // 余额不足通知 field.Bool("balance_notify_enabled"). @@ -104,6 +115,8 @@ func (User) Edges() []ent.Edge { edge.To("attribute_values", UserAttributeValue.Type), edge.To("promo_code_usages", PromoCodeUsage.Type), edge.To("payment_orders", PaymentOrder.Type), + edge.To("auth_identities", AuthIdentity.Type), + edge.To("pending_auth_sessions", PendingAuthSession.Type), } } diff --git a/backend/ent/tx.go b/backend/ent/tx.go index bb3139d5..bde3e35b 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -24,18 +24,26 @@ type Tx struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. + AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -202,12 +210,16 @@ func (tx *Tx) init() { tx.AccountGroup = NewAccountGroupClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) + tx.AuthIdentity = NewAuthIdentityClient(tx.config) + tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) + tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config) tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config) tx.PaymentOrder = NewPaymentOrderClient(tx.config) tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config) + tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) diff --git a/backend/ent/user.go b/backend/ent/user.go index 9fa91f74..66f33623 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,12 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // SignupSource holds the value of the "signup_source" field. + SignupSource string `json:"signup_source,omitempty"` + // LastLoginAt holds the value of the "last_login_at" field. + LastLoginAt *time.Time `json:"last_login_at,omitempty"` + // LastActiveAt holds the value of the "last_active_at" field. + LastActiveAt *time.Time `json:"last_active_at,omitempty"` // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field. BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"` // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field. @@ -83,11 +89,15 @@ type UserEdges struct { PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"` // PaymentOrders holds the value of the payment_orders edge. PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"` + // AuthIdentities holds the value of the auth_identities edge. + AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"` + // PendingAuthSessions holds the value of the pending_auth_sessions edge. + PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"` // UserAllowedGroups holds the value of the user_allowed_groups edge. UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [11]bool + loadedTypes [13]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) { return nil, &NotLoadedError{edge: "payment_orders"} } +// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) { + if e.loadedTypes[10] { + return e.AuthIdentities, nil + } + return nil, &NotLoadedError{edge: "auth_identities"} +} + +// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) { + if e.loadedTypes[11] { + return e.PendingAuthSessions, nil + } + return nil, &NotLoadedError{edge: "pending_auth_sessions"} +} + // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[10] { + if e.loadedTypes[12] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case user.FieldID, user.FieldConcurrency: values[i] = new(sql.NullInt64) - case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: + case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: values[i] = new(sql.NullString) - case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: + case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldSignupSource: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field signup_source", values[i]) + } else if value.Valid { + _m.SignupSource = value.String + } + case user.FieldLastLoginAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_login_at", values[i]) + } else if value.Valid { + _m.LastLoginAt = new(time.Time) + *_m.LastLoginAt = value.Time + } + case user.FieldLastActiveAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_active_at", values[i]) + } else if value.Valid { + _m.LastActiveAt = new(time.Time) + *_m.LastActiveAt = value.Time + } case user.FieldBalanceNotifyEnabled: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i]) @@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery { return NewUserClient(_m.config).QueryPaymentOrders(_m) } +// QueryAuthIdentities queries the "auth_identities" edge of the User entity. +func (_m *User) QueryAuthIdentities() *AuthIdentityQuery { + return NewUserClient(_m.config).QueryAuthIdentities(_m) +} + +// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity. +func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery { + return NewUserClient(_m.config).QueryPendingAuthSessions(_m) +} + // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { return NewUserClient(_m.config).QueryUserAllowedGroups(_m) @@ -482,6 +540,19 @@ func (_m *User) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + builder.WriteString("signup_source=") + builder.WriteString(_m.SignupSource) + builder.WriteString(", ") + if v := _m.LastLoginAt; v != nil { + builder.WriteString("last_login_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.LastActiveAt; v != nil { + builder.WriteString("last_active_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("balance_notify_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled)) builder.WriteString(", ") diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index d88a3a38..567e3b14 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,12 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldSignupSource holds the string denoting the signup_source field in the database. + FieldSignupSource = "signup_source" + // FieldLastLoginAt holds the string denoting the last_login_at field in the database. + FieldLastLoginAt = "last_login_at" + // FieldLastActiveAt holds the string denoting the last_active_at field in the database. + FieldLastActiveAt = "last_active_at" // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database. FieldBalanceNotifyEnabled = "balance_notify_enabled" // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database. @@ -73,6 +79,10 @@ const ( EdgePromoCodeUsages = "promo_code_usages" // EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations. EdgePaymentOrders = "payment_orders" + // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations. + EdgeAuthIdentities = "auth_identities" + // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations. + EdgePendingAuthSessions = "pending_auth_sessions" // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. EdgeUserAllowedGroups = "user_allowed_groups" // Table holds the table name of the user in the database. @@ -145,6 +155,20 @@ const ( PaymentOrdersInverseTable = "payment_orders" // PaymentOrdersColumn is the table column denoting the payment_orders relation/edge. PaymentOrdersColumn = "user_id" + // AuthIdentitiesTable is the table that holds the auth_identities relation/edge. + AuthIdentitiesTable = "auth_identities" + // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + AuthIdentitiesInverseTable = "auth_identities" + // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge. + AuthIdentitiesColumn = "user_id" + // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge. + PendingAuthSessionsTable = "pending_auth_sessions" + // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. + PendingAuthSessionsInverseTable = "pending_auth_sessions" + // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge. + PendingAuthSessionsColumn = "target_user_id" // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. UserAllowedGroupsTable = "user_allowed_groups" // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. @@ -171,6 +195,9 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldSignupSource, + FieldLastLoginAt, + FieldLastActiveAt, FieldBalanceNotifyEnabled, FieldBalanceNotifyThresholdType, FieldBalanceNotifyThreshold, @@ -232,6 +259,10 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultSignupSource holds the default value on creation for the "signup_source" field. + DefaultSignupSource string + // SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save. + SignupSourceValidator func(string) error // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field. DefaultBalanceNotifyEnabled bool // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field. @@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// BySignupSource orders the results by the signup_source field. +func BySignupSource(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSignupSource, opts...).ToFunc() +} + +// ByLastLoginAt orders the results by the last_login_at field. +func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc() +} + +// ByLastActiveAt orders the results by the last_active_at field. +func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc() +} + // ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field. func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc() @@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByAuthIdentitiesCount orders the results by auth_identities count. +func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...) + } +} + +// ByAuthIdentities orders the results by auth_identities terms. +func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count. +func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...) + } +} + +// ByPendingAuthSessions orders the results by pending_auth_sessions terms. +func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByUserAllowedGroupsCount orders the results by user_allowed_groups count. func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn), ) } +func newAuthIdentitiesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AuthIdentitiesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) +} +func newPendingAuthSessionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) +} func newUserAllowedGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 2788aa7a..cbcfcc26 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ. +func SignupSource(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ. +func LastLoginAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ. +func LastActiveAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + // BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ. func BalanceNotifyEnabled(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// SignupSourceEQ applies the EQ predicate on the "signup_source" field. +func SignupSourceEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field. +func SignupSourceNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSignupSource, v)) +} + +// SignupSourceIn applies the In predicate on the "signup_source" field. +func SignupSourceIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldSignupSource, vs...)) +} + +// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field. +func SignupSourceNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...)) +} + +// SignupSourceGT applies the GT predicate on the "signup_source" field. +func SignupSourceGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldSignupSource, v)) +} + +// SignupSourceGTE applies the GTE predicate on the "signup_source" field. +func SignupSourceGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldSignupSource, v)) +} + +// SignupSourceLT applies the LT predicate on the "signup_source" field. +func SignupSourceLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldSignupSource, v)) +} + +// SignupSourceLTE applies the LTE predicate on the "signup_source" field. +func SignupSourceLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldSignupSource, v)) +} + +// SignupSourceContains applies the Contains predicate on the "signup_source" field. +func SignupSourceContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldSignupSource, v)) +} + +// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field. +func SignupSourceHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v)) +} + +// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field. +func SignupSourceHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v)) +} + +// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field. +func SignupSourceEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldSignupSource, v)) +} + +// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field. +func SignupSourceContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldSignupSource, v)) +} + +// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field. +func LastLoginAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field. +func LastLoginAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtIn applies the In predicate on the "last_login_at" field. +func LastLoginAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field. +func LastLoginAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtGT applies the GT predicate on the "last_login_at" field. +func LastLoginAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastLoginAt, v)) +} + +// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field. +func LastLoginAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastLoginAt, v)) +} + +// LastLoginAtLT applies the LT predicate on the "last_login_at" field. +func LastLoginAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastLoginAt, v)) +} + +// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field. +func LastLoginAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastLoginAt, v)) +} + +// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field. +func LastLoginAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastLoginAt)) +} + +// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field. +func LastLoginAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastLoginAt)) +} + +// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field. +func LastActiveAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field. +func LastActiveAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtIn applies the In predicate on the "last_active_at" field. +func LastActiveAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field. +func LastActiveAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtGT applies the GT predicate on the "last_active_at" field. +func LastActiveAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastActiveAt, v)) +} + +// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field. +func LastActiveAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastActiveAt, v)) +} + +// LastActiveAtLT applies the LT predicate on the "last_active_at" field. +func LastActiveAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastActiveAt, v)) +} + +// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field. +func LastActiveAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastActiveAt, v)) +} + +// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field. +func LastActiveAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastActiveAt)) +} + +// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field. +func LastActiveAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastActiveAt)) +} + // BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field. func BalanceNotifyEnabledEQ(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User { }) } +// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge. +func HasAuthIdentities() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates). +func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newAuthIdentitiesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge. +func HasPendingAuthSessions() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates). +func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newPendingAuthSessionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. func HasUserAllowedGroups() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index fbc64f9c..db95e813 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetSignupSource sets the "signup_source" field. +func (_c *UserCreate) SetSignupSource(v string) *UserCreate { + _c.mutation.SetSignupSource(v) + return _c +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate { + if v != nil { + _c.SetSignupSource(*v) + } + return _c +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate { + _c.mutation.SetLastLoginAt(v) + return _c +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastLoginAt(*v) + } + return _c +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate { + _c.mutation.SetLastActiveAt(v) + return _c +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastActiveAt(*v) + } + return _c +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate { _c.mutation.SetBalanceNotifyEnabled(v) @@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate { return _c.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate { + _c.mutation.AddAuthIdentityIDs(ids...) + return _c +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate { + _c.mutation.AddPendingAuthSessionIDs(ids...) + return _c +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_c *UserCreate) Mutation() *UserMutation { return _c.mutation @@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.SignupSource(); !ok { + v := user.DefaultSignupSource + _c.mutation.SetSignupSource(v) + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { v := user.DefaultBalanceNotifyEnabled _c.mutation.SetBalanceNotifyEnabled(v) @@ -589,6 +667,14 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.SignupSource(); !ok { + return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)} + } + if v, ok := _c.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)} } @@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + _node.SignupSource = value + } + if value, ok := _c.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + _node.LastLoginAt = &value + } + if value, ok := _c.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + _node.LastActiveAt = &value + } if value, ok := _c.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) _node.BalanceNotifyEnabled = value @@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsert) SetSignupSource(v string) *UserUpsert { + u.Set(user.FieldSignupSource, v) + return u +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsert) UpdateSignupSource() *UserUpsert { + u.SetExcluded(user.FieldSignupSource) + return u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastLoginAt, v) + return u +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert { + u.SetExcluded(user.FieldLastLoginAt) + return u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsert) ClearLastLoginAt() *UserUpsert { + u.SetNull(user.FieldLastLoginAt) + return u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastActiveAt, v) + return u +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert { + u.SetExcluded(user.FieldLastActiveAt) + return u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsert) ClearLastActiveAt() *UserUpsert { + u.SetNull(user.FieldLastActiveAt) + return u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert { u.Set(user.FieldBalanceNotifyEnabled, v) @@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne { return u.Update(func(s *UserUpsert) { @@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk { return u.Update(func(s *UserUpsert) { diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index 113d87ac..f1ee5cfe 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -15,8 +15,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -44,6 +46,8 @@ type UserQuery struct { withAttributeValues *UserAttributeValueQuery withPromoCodeUsages *PromoCodeUsageQuery withPaymentOrders *PaymentOrderQuery + withAuthIdentities *AuthIdentityQuery + withPendingAuthSessions *PendingAuthSessionQuery withUserAllowedGroups *UserAllowedGroupQuery modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). @@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery { return query } +// QueryAuthIdentities chains the current query on the "auth_identities" edge. +func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge. +func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: _q.config}).Query() @@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery { withAttributeValues: _q.withAttributeValues.Clone(), withPromoCodeUsages: _q.withPromoCodeUsages.Clone(), withPaymentOrders: _q.withPaymentOrders.Clone(), + withAuthIdentities: _q.withAuthIdentities.Clone(), + withPendingAuthSessions: _q.withPendingAuthSessions.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu return _q } +// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to +// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAuthIdentities = query + return _q +} + +// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSessions = query + return _q +} + // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { @@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = _q.querySpec() - loadedTypes = [11]bool{ + loadedTypes = [13]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, @@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e _q.withAttributeValues != nil, _q.withPromoCodeUsages != nil, _q.withPaymentOrders != nil, + _q.withAuthIdentities != nil, + _q.withPendingAuthSessions != nil, _q.withUserAllowedGroups != nil, } ) @@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := _q.withAuthIdentities; query != nil { + if err := _q.loadAuthIdentities(ctx, query, nodes, + func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} }, + func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil { + return nil, err + } + } + if query := _q.withPendingAuthSessions; query != nil { + if err := _q.loadPendingAuthSessions(ctx, query, nodes, + func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} }, + func(n *User, e *PendingAuthSession) { + n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e) + }); err != nil { + return nil, err + } + } if query := _q.withUserAllowedGroups; query != nil { if err := _q.loadUserAllowedGroups(ctx, query, nodes, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, @@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ } return nil } +func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(authidentity.FieldUserID) + } + query.Where(predicate.AuthIdentity(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID) + } + query.Where(predicate.PendingAuthSession(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TargetUserID + if fk == nil { + return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*User) diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 6b355247..677eeb6b 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetSignupSource sets the "signup_source" field. +func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate { _u.mutation.SetBalanceNotifyEnabled(v) @@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdate) Mutation() *UserMutation { return _u.mutation @@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. +func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetSignupSource sets the "signup_source" field. +func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne { _u.mutation.SetBalanceNotifyEnabled(v) @@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdateOne) Mutation() *UserMutation { return _u.mutation @@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. +func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { _u.mutation.Where(ps...) @@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &User{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index dd9a4e58..6136e9ea 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1608,6 +1608,9 @@ func (c *Config) Validate() error { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } if c.LinuxDo.Enabled { + if !c.LinuxDo.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true") + } if strings.TrimSpace(c.LinuxDo.ClientID) == "" { return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") } @@ -1629,9 +1632,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.LinuxDo.UsePKCE { - return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") @@ -1663,6 +1663,12 @@ func (c *Config) Validate() error { warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) } if c.OIDC.Enabled { + if !c.OIDC.UsePKCE { + return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true") + } + if !c.OIDC.ValidateIDToken { + return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true") + } if strings.TrimSpace(c.OIDC.ClientID) == "" { return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") } @@ -1685,9 +1691,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.OIDC.UsePKCE { - return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.OIDC.ClientSecret) == "" { return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index bec0f126..fe5c7928 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -73,6 +73,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // Check if ops monitoring is enabled (respects config.ops.enabled) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) @@ -93,7 +98,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { paymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ + payload := dto.SystemSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, @@ -200,7 +205,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode, - }) + } + response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } // UpdateSettingsRequest 更新设置请求 @@ -276,9 +282,30 @@ type UpdateSettingsRequest struct { CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` + AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` + AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"` + AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"` + AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"` + AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"` + AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"` + AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"` + AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"` + AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"` + AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"` + AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"` + AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"` + AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"` + AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"` + AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"` + AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"` + AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"` + AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"` + AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"` + ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -357,6 +384,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // 验证参数 if req.DefaultConcurrency < 1 { @@ -381,6 +413,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.SMTPPort = 587 } req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) + req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions) + req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions) + req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions) + req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions) // SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置 // 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置 @@ -538,25 +574,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "OIDC scopes must contain openid") return } + if !req.OIDCConnectUsePKCE { + response.BadRequest(c, "OIDC PKCE must be enabled") + return + } + if !req.OIDCConnectValidateIDToken { + response.BadRequest(c, "OIDC ID Token validation must be enabled") + return + } switch req.OIDCConnectTokenAuthMethod { case "", "client_secret_post", "client_secret_basic", "none": default: response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none") return } - if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE { - response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none") - return - } if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 { response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600") return } - if req.OIDCConnectValidateIDToken { - if req.OIDCConnectAllowedSigningAlgs == "" { - response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") - return - } + if req.OIDCConnectAllowedSigningAlgs == "" { + response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") + return } if req.OIDCConnectJWKSURL != "" { if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil { @@ -933,6 +971,41 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + authSourceDefaults := &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind), + }, + LinuxDo: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind), + }, + OIDC: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind), + }, + WeChat: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind), + }, + ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), + } + if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil { + response.ErrorFrom(c, err) + return + } // Update payment configuration (integrated into system settings). // Skip if no payment fields were provided (prevents accidental wipe). @@ -977,6 +1050,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) for _, sub := range updatedSettings.DefaultSubscriptions { updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ @@ -994,7 +1072,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { updatedPaymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ + payload := dto.SystemSettings{ RegistrationEnabled: updatedSettings.RegistrationEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, @@ -1100,7 +1178,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode, - }) + } + response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } // hasPaymentFields returns true if any payment-related field was explicitly provided. @@ -1412,6 +1491,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto return normalized } +func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting { + if input == nil { + return nil + } + normalized := normalizeDefaultSubscriptions(*input) + return &normalized +} + +func float64ValueOrDefault(value *float64, fallback float64) float64 { + if value == nil { + return fallback + } + return *value +} + +func intValueOrDefault(value *int, fallback int) int { + if value == nil { + return fallback + } + return *value +} + +func boolValueOrDefault(value *bool, fallback bool) bool { + if value == nil { + return fallback + } + return *value +} + +func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting { + if input == nil { + return fallback + } + result := make([]service.DefaultSubscriptionSetting, 0, len(*input)) + for _, item := range *input { + result = append(result, service.DefaultSubscriptionSetting{ + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + }) + } + return result +} + +func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any { + data := make(map[string]any) + raw, err := json.Marshal(settings) + if err == nil { + _ = json.Unmarshal(raw, &data) + } + if authSourceDefaults == nil { + authSourceDefaults = &service.AuthSourceDefaultSettings{} + } + + data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance + data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency + data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions + data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup + data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind + data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance + data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency + data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions + data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup + data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind + data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance + data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency + data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions + data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup + data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind + data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance + data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency + data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions + data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup + data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind + data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup + + return data +} + func equalStringSlice(a, b []string) bool { if len(a) != len(b) { return false diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go new file mode 100644 index 00000000..b26fa447 --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -0,0 +1,149 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerRepoStub struct { + values map[string]string + lastUpdates map[string]string +} + +func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.lastUpdates = make(map[string]string, len(settings)) + for key, value := range settings { + s.lastUpdates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil) + + handler.GetSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 9.5, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) + require.Equal(t, true, data["force_email_on_third_party_signup"]) + + subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any) + require.True(t, ok) + require.Len(t, subscriptions, 1) +} + +func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "false", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "registration_enabled": true, + "promo_code_enabled": true, + "auth_source_default_email_balance": 12.75, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions]) + require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup]) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 12.75, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) + require.Equal(t, true, data["force_email_on_third_party_signup"]) +} diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 2f182642..b0edcf5a 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -219,7 +219,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ @@ -262,6 +262,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { ProviderKey: "linuxdo", ProviderSubject: subject, }, + TargetUserID: &user.ID, ResolvedEmail: email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -287,7 +288,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } type completeLinuxDoOAuthRequest struct { - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating @@ -335,11 +338,23 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 90bc10d1..661c0da0 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -1,10 +1,21 @@ package handler import ( + "bytes" + "context" + "net/http" + "net/http/httptest" "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -110,3 +121,79 @@ func TestSingleLineStripsWhitespace(t *testing.T) { require.Equal(t, "hello world", singleLine("hello\r\nworld")) require.Equal(t, "", singleLine("\n\t\r")) } + +func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-subject-1"). + SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "LinuxDo Display", + "suggested_avatar_url": "https://cdn.example/linuxdo.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "LinuxDo Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("linuxdo-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index a758c0b9..da8ac858 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -1,10 +1,17 @@ package handler import ( + "context" + "errors" + "io" "net/http" "net/url" "strings" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -26,6 +33,7 @@ const ( type oauthPendingSessionPayload struct { Intent string Identity service.PendingAuthIdentityKey + TargetUserID *int64 ResolvedEmail string RedirectTo string BrowserSessionKey string @@ -33,6 +41,11 @@ type oauthPendingSessionPayload struct { CompletionResponse map[string]any } +type oauthAdoptionDecisionRequest struct { + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { if h == nil || h.authService == nil || h.authService.EntClient() == nil { return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") @@ -125,6 +138,7 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{ Intent: strings.TrimSpace(payload.Intent), Identity: payload.Identity, + TargetUserID: payload.TargetUserID, ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail), RedirectTo: strings.TrimSpace(payload.RedirectTo), BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey), @@ -175,6 +189,291 @@ func pendingSessionWantsInvitation(payload map[string]any) bool { return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") } +func (r oauthAdoptionDecisionRequest) hasDecision() bool { + return r.AdoptDisplayName != nil || r.AdoptAvatar != nil +} + +func (r oauthAdoptionDecisionRequest) toServiceInput(sessionID int64) service.PendingIdentityAdoptionDecisionInput { + input := service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + } + if r.AdoptDisplayName != nil { + input.AdoptDisplayName = *r.AdoptDisplayName + } + if r.AdoptAvatar != nil { + input.AdoptAvatar = *r.AdoptAvatar + } + return input +} + +func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) { + var req oauthAdoptionDecisionRequest + if c == nil || c.Request == nil || c.Request.Body == nil { + return req, nil + } + if err := c.ShouldBindJSON(&req); err != nil { + if errors.Is(err, io.EOF) { + return req, nil + } + return req, err + } + return req, nil +} + +func persistPendingOAuthAdoptionDecision( + c *gin.Context, + svc *service.AuthPendingIdentityService, + sessionID int64, + req oauthAdoptionDecisionRequest, +) error { + if !req.hasDecision() { + return nil + } + if svc == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + if _, err := svc.UpsertAdoptionDecision(c.Request.Context(), req.toServiceInput(sessionID)); err != nil { + return infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return nil +} + +func cloneOAuthMetadata(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func normalizeAdoptedOAuthDisplayName(value string) string { + value = strings.TrimSpace(value) + if len([]rune(value)) > 100 { + value = string([]rune(value)[:100]) + } + return value +} + +func (h *AuthHandler) entClient() *dbent.Client { + if h == nil || h.authService == nil { + return nil + } + return h.authService.EntClient() +} + +func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) (*dbent.IdentityAdoptionDecision, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + existing, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)). + Only(c.Request.Context()) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err) + } + if existing != nil && !req.hasDecision() { + return existing, nil + } + if existing == nil && !req.hasDecision() { + return nil, nil + } + + input := service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + } + if existing != nil { + input.AdoptDisplayName = existing.AdoptDisplayName + input.AdoptAvatar = existing.AdoptAvatar + input.IdentityID = existing.IdentityID + } + if req.AdoptDisplayName != nil { + input.AdoptDisplayName = *req.AdoptDisplayName + } + if req.AdoptAvatar != nil { + input.AdoptAvatar = *req.AdoptAvatar + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) { + if session == nil { + return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + if session.TargetUserID != nil && *session.TargetUserID > 0 { + return *session.TargetUserID, nil + } + email := strings.TrimSpace(session.ResolvedEmail) + if email == "" { + return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing") + } + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(email)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found") + } + return 0, err + } + return userEntity.ID, nil +} + +func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { + if session == nil { + return nil + } + switch strings.TrimSpace(session.ProviderType) { + case "oidc": + issuer := strings.TrimSpace(session.ProviderKey) + if issuer == "" { + issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + } + if issuer == "" { + return nil + } + return &issuer + default: + issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + if issuer == "" { + return nil + } + return &issuer + } +} + +func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + client := tx.Client() + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if identity != nil { + if identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + return identity, nil + } + + create := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(strings.TrimSpace(session.ProviderType)). + SetProviderKey(strings.TrimSpace(session.ProviderKey)). + SetProviderSubject(strings.TrimSpace(session.ProviderSubject)). + SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + return create.Save(ctx) +} + +func applyPendingOAuthAdoption( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, +) error { + if client == nil || session == nil || decision == nil { + return nil + } + if !decision.AdoptDisplayName && !decision.AdoptAvatar { + return nil + } + + targetUserID := int64(0) + if overrideUserID != nil && *overrideUserID > 0 { + targetUserID = *overrideUserID + } else { + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session) + if err != nil { + return err + } + targetUserID = resolvedUserID + } + + adoptedDisplayName := "" + if decision.AdoptDisplayName { + adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name")) + } + adoptedAvatarURL := "" + if decision.AdoptAvatar { + adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") + } + + tx, err := client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + if decision.AdoptDisplayName && adoptedDisplayName != "" { + if err := tx.Client().User.UpdateOneID(targetUserID). + SetUsername(adoptedDisplayName). + Exec(ctx); err != nil { + return err + } + } + + identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID) + if err != nil { + return err + } + + metadata := cloneOAuthMetadata(identity.Metadata) + for key, value := range session.UpstreamIdentityClaims { + metadata[key] = value + } + if decision.AdoptDisplayName && adoptedDisplayName != "" { + metadata["display_name"] = adoptedDisplayName + } + if decision.AdoptAvatar && adoptedAvatarURL != "" { + metadata["avatar_url"] = adoptedAvatarURL + } + + updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer)) + } + if _, err := updateIdentity.Save(ctx); err != nil { + return err + } + + if decision.IdentityID == nil || *decision.IdentityID != identity.ID { + if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). + SetIdentityID(identity.ID). + Save(ctx); err != nil { + return err + } + } + + return tx.Commit() +} + func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { if len(payload) == 0 || len(upstream) == 0 { return @@ -206,6 +505,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) } + adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c) + if err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } sessionToken, err := readOAuthPendingSessionCookie(c) if err != nil || strings.TrimSpace(sessionToken) == "" { @@ -248,9 +552,30 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) if pendingSessionWantsInvitation(payload) { + if adoptionDecision.hasDecision() { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision) + if err != nil { + response.ErrorFrom(c, err) + return + } + _ = decision + } response.Success(c, payload) return } + if !adoptionDecision.hasDecision() { + response.Success(c, payload) + return + } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearCookies() diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 5517bae2..829fc217 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -1,9 +1,30 @@ package handler import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" "testing" + "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" ) func TestApplySuggestedProfileToCompletionResponse(t *testing.T) { @@ -38,3 +59,439 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t * require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) require.Equal(t, true, payload["adoption_required"]) } + +func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("linuxdo-123@linuxdo-connect.invalid"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Alice Example", + "suggested_avatar_url": "https://cdn.example/alice.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previewRecorder := httptest.NewRecorder() + previewCtx, _ := gin.CreateTestContext(previewRecorder) + previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + previewCtx.Request = previewReq + + handler.ExchangePendingOAuthCompletion(previewCtx) + + require.Equal(t, http.StatusOK, previewRecorder.Code) + previewData := decodeJSONResponseData(t, previewRecorder) + require.Equal(t, "Alice Example", previewData["suggested_display_name"]) + require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"]) + require.Equal(t, true, previewData["adoption_required"]) + + storedUser, err := client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "legacy-name", storedUser.Username) + + previewSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, previewSession.ConsumedAt) + + body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`) + finalizeRecorder := httptest.NewRecorder() + finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder) + finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + finalizeReq.Header.Set("Content-Type", "application/json") + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + finalizeCtx.Request = finalizeReq + + handler.ExchangePendingOAuthCompletion(finalizeCtx) + + require.Equal(t, http.StatusOK, finalizeRecorder.Code) + + storedUser, err = client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "Alice Example", storedUser.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "Alice Example", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + }, + }, cfg) + authSvc := service.NewAuthService( + client, + &oauthPendingFlowUserRepo{client: client}, + nil, + &oauthPendingFlowRefreshTokenCacheStub{}, + cfg, + settingSvc, + nil, + nil, + nil, + nil, + nil, + ) + + return &AuthHandler{ + authService: authSvc, + settingSvc: settingSvc, + }, client +} + +func boolSettingValue(v bool) string { + if v { + return "true" + } + return "false" +} + +func boolPtr(v bool) *bool { + return &v +} + +type oauthPendingFlowSettingRepoStub struct { + values map[string]string +} + +func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type oauthPendingFlowRefreshTokenCacheStub struct{} + +func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + +func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var envelope struct { + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope)) + return envelope.Data +} + +func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + return payload +} + +type oauthPendingFlowUserRepo struct { + client *dbent.Client +} + +func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error { + entity, err := r.client.User.Create(). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). + Save(ctx) + if err != nil { + return err + } + user.ID = entity.ID + user.CreatedAt = entity.CreatedAt + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { + entity, err := r.client.User.Get(ctx, id) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { + entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error { + entity, err := r.client.User.UpdateOneID(user.ID). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). + Save(ctx) + if err != nil { + return err + } + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { + return r.client.User.DeleteOneID(id).Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + return nil, service.ErrUserNotFound +} + +func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error { + return nil +} + +func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error { + panic("unexpected UpdateBalance call") +} + +func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error { + panic("unexpected DeductBalance call") +} + +func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error { + panic("unexpected UpdateConcurrency call") +} + +func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx) + return count > 0, err +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + panic("unexpected RemoveGroupFromUserAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { + panic("unexpected UpdateTotpSecret call") +} + +func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error { + panic("unexpected EnableTotp call") +} + +func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error { + panic("unexpected DisableTotp call") +} + +func oauthPendingFlowServiceUser(entity *dbent.User) *service.User { + if entity == nil { + return nil + } + return &service.User{ + ID: entity.ID, + Email: entity.Email, + Username: entity.Username, + Notes: entity.Notes, + PasswordHash: entity.PasswordHash, + Role: entity.Role, + Balance: entity.Balance, + Concurrency: entity.Concurrency, + Status: entity.Status, + SignupSource: entity.SignupSource, + LastLoginAt: entity.LastLoginAt, + LastActiveAt: entity.LastActiveAt, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, + } +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index e3694c8f..ceda633c 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -326,7 +326,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ) // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ @@ -371,6 +371,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ProviderKey: issuer, ProviderSubject: subject, }, + TargetUserID: &user.ID, ResolvedEmail: email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -399,7 +400,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } type completeOIDCOAuthRequest struct { - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating @@ -447,11 +450,23 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index c389db51..9107e13a 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -12,7 +13,13 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" ) @@ -123,3 +130,80 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { E: e, } } + +func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-subject-1"). + SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + "suggested_display_name": "OIDC Display", + "suggested_avatar_url": "https://cdn.example/oidc.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "OIDC Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example.com"), + authidentity.ProviderSubjectEQ("oidc-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "OIDC Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go new file mode 100644 index 00000000..867a77a1 --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -0,0 +1,618 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +const ( + wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat" + wechatOAuthCookieMaxAgeSec = 10 * 60 + wechatOAuthStateCookieName = "wechat_oauth_state" + wechatOAuthRedirectCookieName = "wechat_oauth_redirect" + wechatOAuthIntentCookieName = "wechat_oauth_intent" + wechatOAuthModeCookieName = "wechat_oauth_mode" + wechatOAuthDefaultRedirectTo = "/dashboard" + wechatOAuthDefaultFrontendCB = "/auth/wechat/callback" + wechatOAuthProviderKey = "wechat-main" + + wechatOAuthIntentLogin = "login" + wechatOAuthIntentBind = "bind_current_user" + wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email" +) + +var ( + wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token" + wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo" +) + +type wechatOAuthConfig struct { + mode string + appID string + appSecret string + authorizeURL string + scope string + redirectURI string + frontendCallback string +} + +type wechatOAuthTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + OpenID string `json:"openid"` + Scope string `json:"scope"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +type wechatOAuthUserInfoResponse struct { + OpenID string `json:"openid"` + Nickname string `json:"nickname"` + HeadImgURL string `json:"headimgurl"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived +// browser cookies required by the rebuild pending-auth bridge. +func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) { + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + + intent := normalizeWeChatOAuthIntent(c.Query("intent")) + secureCookie := isRequestHTTPS(c) + wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + + authURL, err := buildWeChatAuthorizeURL(cfg, state) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid, +// and stores the result in the unified pending-auth flow. +func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { + frontendCallback := wechatOAuthFrontendCallback() + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } + + intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName) + mode, err := readCookieDecoded(c, wechatOAuthModeCookieName) + if err != nil || strings.TrimSpace(mode) == "" { + redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "") + return + } + + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error())) + return + } + + unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID)) + openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID)) + providerSubject := firstNonEmpty(unionid, openid) + if providerSubject == "" { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "") + return + } + + username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject)) + email := wechatSyntheticEmail(providerSubject) + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": providerSubject, + "openid": openid, + "unionid": unionid, + "mode": cfg.mode, + "suggested_display_name": strings.TrimSpace(userInfo.Nickname), + "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL), + } + + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + if err != nil { + if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +type completeWeChatOAuthRequest struct { + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by +// validating the invitation code and consuming the current pending browser session. +// POST /api/v1/auth/oauth/wechat/complete-registration +func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { + var req completeWeChatOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) + return + } + + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) createWeChatPendingSession( + c *gin.Context, + intent string, + providerSubject string, + email string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + tokenPair *service.TokenPair, + authErr error, +) error { + completionResponse := map[string]any{ + "redirect": redirectTo, + } + if authErr != nil { + if errors.Is(authErr, service.ErrOAuthInvitationRequired) { + completionResponse["error"] = "invitation_required" + } else { + return authErr + } + } else if tokenPair != nil { + completionResponse["access_token"] = tokenPair.AccessToken + completionResponse["refresh_token"] = tokenPair.RefreshToken + completionResponse["expires_in"] = tokenPair.ExpiresIn + completionResponse["token_type"] = "Bearer" + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: intent, + Identity: service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + }, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) +} + +func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) { + mode, err := resolveWeChatOAuthMode(rawMode, c) + if err != nil { + return wechatOAuthConfig{}, err + } + + apiBaseURL := "" + if h != nil && h.settingSvc != nil { + settings, err := h.settingSvc.GetAllSettings(ctx) + if err == nil && settings != nil { + apiBaseURL = strings.TrimSpace(settings.APIBaseURL) + } + } + + cfg := wechatOAuthConfig{ + mode: mode, + redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"), + frontendCallback: wechatOAuthFrontendCallback(), + } + + switch mode { + case "mp": + cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) + cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) + cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize" + cfg.scope = "snsapi_userinfo" + default: + cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) + cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) + cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect" + cfg.scope = "snsapi_login" + } + + if cfg.appID == "" || cfg.appSecret == "" { + return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + if strings.TrimSpace(cfg.redirectURI) == "" { + return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") + } + + return cfg, nil +} + +func wechatOAuthFrontendCallback() string { + return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB) +} + +func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) { + mode := strings.ToLower(strings.TrimSpace(rawMode)) + if mode == "" { + if isWeChatBrowserRequest(c) { + return "mp", nil + } + return "open", nil + } + if mode != "open" && mode != "mp" { + return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp") + } + return mode, nil +} + +func isWeChatBrowserRequest(c *gin.Context) bool { + if c == nil || c.Request == nil { + return false + } + return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger") +} + +func normalizeWeChatOAuthIntent(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", "login": + return wechatOAuthIntentLogin + case "bind", "bind_current_user": + return wechatOAuthIntentBind + case "adopt", "adopt_existing_user_by_email": + return wechatOAuthIntentAdoptEmail + default: + return wechatOAuthIntentLogin + } +} + +func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) { + u, err := url.Parse(cfg.authorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize url: %w", err) + } + query := u.Query() + query.Set("appid", cfg.appID) + query.Set("redirect_uri", cfg.redirectURI) + query.Set("response_type", "code") + query.Set("scope", cfg.scope) + query.Set("state", state) + u.RawQuery = query.Encode() + u.Fragment = "wechat_redirect" + return u.String(), nil +} + +func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string { + callbackPath = strings.TrimSpace(callbackPath) + if callbackPath == "" { + return "" + } + + if raw := strings.TrimSpace(apiBaseURL); raw != "" { + if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" { + basePath := strings.TrimRight(parsed.EscapedPath(), "/") + targetPath := callbackPath + if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") { + targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1") + } else if basePath != "" { + targetPath = basePath + callbackPath + } + return parsed.Scheme + "://" + parsed.Host + targetPath + } + } + + if c == nil || c.Request == nil { + return "" + } + scheme := "http" + if isRequestHTTPS(c) { + scheme = "https" + } + host := strings.TrimSpace(c.Request.Host) + if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" { + host = forwardedHost + } + if host == "" { + return "" + } + return scheme + "://" + host + callbackPath +} + +func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) { + tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code) + if err != nil { + return nil, nil, err + } + userInfo, err := fetchWeChatUserInfo(ctx, tokenResp) + if err != nil { + return nil, nil, err + } + return tokenResp, userInfo, nil +} + +func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) { + endpoint, err := url.Parse(wechatOAuthAccessTokenURL) + if err != nil { + return nil, fmt.Errorf("parse wechat access token url: %w", err) + } + + query := endpoint.Query() + query.Set("appid", cfg.appID) + query.Set("secret", cfg.appSecret) + query.Set("code", strings.TrimSpace(code)) + query.Set("grant_type", "authorization_code") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat access token request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat access token: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat access token response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode) + } + + var tokenResp wechatOAuthTokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("decode wechat access token response: %w", err) + } + if tokenResp.ErrCode != 0 { + return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg)) + } + if strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, fmt.Errorf("wechat access token missing access_token") + } + return &tokenResp, nil +} + +func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) { + if tokenResp == nil { + return nil, fmt.Errorf("wechat token response is nil") + } + + endpoint, err := url.Parse(wechatOAuthUserInfoURL) + if err != nil { + return nil, fmt.Errorf("parse wechat userinfo url: %w", err) + } + query := endpoint.Query() + query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken)) + query.Set("openid", strings.TrimSpace(tokenResp.OpenID)) + query.Set("lang", "zh_CN") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat userinfo request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat userinfo: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat userinfo response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode) + } + + var userInfo wechatOAuthUserInfoResponse + if err := json.Unmarshal(body, &userInfo); err != nil { + return nil, fmt.Errorf("decode wechat userinfo response: %w", err) + } + if userInfo.ErrCode != 0 { + return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg)) + } + return &userInfo, nil +} + +func wechatSyntheticEmail(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "" + } + return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain +} + +func wechatFallbackUsername(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "wechat_user" + } + return "wechat_" + truncateFragmentValue(subject) +} + +func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: wechatOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func wechatClearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: wechatOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go new file mode 100644 index 00000000..1a765dcc --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -0,0 +1,411 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "database/sql" + "encoding/base64" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) + c.Request.Host = "api.example.com" + + handler := &AuthHandler{} + handler.WeChatOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.NotEmpty(t, location) + require.Contains(t, location, "open.weixin.qq.com") + require.Contains(t, location, "appid=wx-open-app") + require.Contains(t, location, "scope=snsapi_login") + + cookies := recorder.Result().Cookies() + require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName)) + require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName)) +} + +func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "wechat", session.ProviderType) + require.Equal(t, "wechat-main", session.ProviderKey) + require.Equal(t, "union-456", session.ProviderSubject) + require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail) + require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"]) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) + require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) +} + +func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, true) + defer client.Close() + + ctx := context.Background() + redeemRepo := repository.NewRedeemCodeRepository(client) + require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{ + Code: "invite-1", + Type: service.RedeemTypeInvitation, + Status: service.StatusUnused, + })) + + callbackRecorder := httptest.NewRecorder() + callbackCtx, _ := gin.CreateTestContext(callbackRecorder) + callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + callbackReq.Host = "api.example.com" + callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + callbackCtx.Request = callbackReq + + handler.WeChatOAuthCallback(callbackCtx) + + require.Equal(t, http.StatusFound, callbackRecorder.Code) + require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location")) + + sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + sessionToken := decodeCookieValueForTest(t, sessionCookie.Value) + + pendingSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"]) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`) + completeRecorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(completeRecorder) + completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + completeReq.Header.Set("Content-Type", "application/json") + completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)}) + completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")}) + completeCtx.Request = completeReq + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, completeRecorder.Code) + responseData := decodeJSONBody(t, completeRecorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "WeChat Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "WeChat Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(pendingSession.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + userRepo := &oauthPendingFlowUserRepo{client: client} + redeemRepo := repository.NewRedeemCodeRepository(client) + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + }, + }, &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + }) + + authSvc := service.NewAuthService( + client, + userRepo, + redeemRepo, + &wechatOAuthRefreshTokenCacheStub{}, + &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + }, + settingSvc, + nil, + nil, + nil, + nil, + nil, + ) + + return &AuthHandler{ + authService: authSvc, + settingSvc: settingSvc, + }, client +} + +func encodedCookie(name, value string) *http.Cookie { + return &http.Cookie{ + Name: name, + Value: encodeCookieValue(value), + Path: "/", + } +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} + +func decodeCookieValueForTest(t *testing.T, value string) string { + t.Helper() + raw, err := base64.RawURLEncoding.DecodeString(value) + require.NoError(t, err) + return string(raw) +} + +type wechatOAuthSettingRepoStub struct { + values map[string]string +} + +func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type wechatOAuthRefreshTokenCacheStub struct{} + +func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 3659e79b..f44b3e3b 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -189,6 +189,7 @@ type PublicSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` SoraClientEnabled bool `json:"sora_client_enabled"` diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go index 8a83bfeb..9fdefa93 100644 --- a/backend/internal/handler/payment_webhook_handler.go +++ b/backend/internal/handler/payment_webhook_handler.go @@ -120,7 +120,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) // This allows looking up the correct provider instance before verification. func extractOutTradeNo(rawBody, providerKey string) string { switch providerKey { - case payment.TypeEasyPay: + case payment.TypeEasyPay, payment.TypeAlipay: values, err := url.ParseQuery(rawBody) if err == nil { return values.Get("out_trade_no") diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go index bdef1766..6f448131 100644 --- a/backend/internal/handler/payment_webhook_handler_test.go +++ b/backend/internal/handler/payment_webhook_handler_test.go @@ -97,3 +97,37 @@ func TestWebhookConstants(t *testing.T) { assert.Equal(t, 200, webhookLogTruncateLen) }) } + +func TestExtractOutTradeNo(t *testing.T) { + tests := []struct { + name string + providerKey string + rawBody string + want string + }{ + { + name: "easypay query payload", + providerKey: "easypay", + rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS", + want: "sub2_123", + }, + { + name: "alipay query payload", + providerKey: "alipay", + rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456", + want: "sub2_456", + }, + { + name: "unknown provider", + providerKey: "wxpay", + rawBody: "{}", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey)) + }) + } +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 1717b7a1..c7bc3e2a 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -56,6 +56,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, BackendModeEnabled: settings.BackendModeEnabled, diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 2535ea5e..904341d0 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -34,10 +34,16 @@ type ChangePasswordRequest struct { // UpdateProfileRequest represents the update profile request payload type UpdateProfileRequest struct { Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type userProfileResponse struct { + dto.User + AvatarURL string `json:"avatar_url,omitempty"` +} + // GetProfile handles getting user profile // GET /api/v1/users/me func (h *UserHandler) GetProfile(c *gin.Context) { @@ -47,13 +53,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) { return } - userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, dto.UserFromService(userData)) + response.Success(c, userProfileResponseFromService(userData)) } // ChangePassword handles changing user password @@ -101,6 +107,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { svcReq := service.UpdateProfileRequest{ Username: req.Username, + AvatarURL: req.AvatarURL, BalanceNotifyEnabled: req.BalanceNotifyEnabled, BalanceNotifyThreshold: req.BalanceNotifyThreshold, } @@ -110,7 +117,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) } // SendNotifyEmailCodeRequest represents the request to send notify email verification code @@ -176,7 +183,7 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) } // RemoveNotifyEmailRequest represents the request to remove a notify email @@ -212,7 +219,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) } // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state @@ -248,5 +255,16 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) +} + +func userProfileResponseFromService(user *service.User) userProfileResponse { + base := dto.UserFromService(user) + if base == nil { + return userProfileResponse{} + } + return userProfileResponse{ + User: *base, + AvatarURL: user.AvatarURL, + } } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go new file mode 100644 index 00000000..1973f59e --- /dev/null +++ b/backend/internal/handler/user_handler_test.go @@ -0,0 +1,136 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userHandlerRepoStub struct { + user *service.User +} + +func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil } +func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error { + cloned := *user + s.user = &cloned + return nil +} +func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + if s.user == nil || s.user.AvatarURL == "" { + return nil, nil + } + return &service.UserAvatar{ + StorageProvider: s.user.AvatarSource, + URL: s.user.AvatarURL, + ContentType: s.user.AvatarMIME, + ByteSize: s.user.AvatarByteSize, + SHA256: s.user.AvatarSHA256, + }, nil +} +func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + s.user.AvatarURL = input.URL + s.user.AvatarSource = input.StorageProvider + s.user.AvatarMIME = input.ContentType + s.user.AvatarByteSize = input.ByteSize + s.user.AvatarSHA256 = input.SHA256 + return &service.UserAvatar{ + StorageProvider: input.StorageProvider, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error { + s.user.AvatarURL = "" + s.user.AvatarSource = "" + s.user.AvatarMIME = "" + s.user.AvatarByteSize = 0 + s.user.AvatarSHA256 = "" + return nil +} +func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } +func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil } + +func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "handler-avatar@example.com", + Username: "handler-avatar", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + + body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.UpdateProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + AvatarURL string `json:"avatar_url"` + Username string `json:"username"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL) + require.Equal(t, "handler-avatar", resp.Data.Username) +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 38ea9bde..36d80309 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se user.FieldBalanceNotifyThreshold, user.FieldBalanceNotifyExtraEmails, user.FieldTotalRecharged, + user.FieldSignupSource, + user.FieldLastLoginAt, + user.FieldLastActiveAt, ) }). WithGroup(func(q *dbent.GroupQuery) { @@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User { Balance: u.Balance, Concurrency: u.Concurrency, Status: u.Status, + SignupSource: u.SignupSource, + LastLoginAt: u.LastLoginAt, + LastActiveAt: u.LastActiveAt, TotpSecretEncrypted: u.TotpSecretEncrypted, TotpEnabled: u.TotpEnabled, TotpEnabledAt: u.TotpEnabledAt, diff --git a/backend/internal/repository/auth_identity_migration_report.go b/backend/internal/repository/auth_identity_migration_report.go new file mode 100644 index 00000000..70f298c1 --- /dev/null +++ b/backend/internal/repository/auth_identity_migration_report.go @@ -0,0 +1,148 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" +) + +type AuthIdentityMigrationReport struct { + ID int64 + ReportType string + ReportKey string + Details map[string]any + CreatedAt time.Time +} + +type AuthIdentityMigrationReportQuery struct { + ReportType string + Limit int + Offset int +} + +type AuthIdentityMigrationReportSummary struct { + Total int64 + ByType map[string]int64 +} + +func (r *userRepository) ListAuthIdentityMigrationReports(ctx context.Context, query AuthIdentityMigrationReportQuery) ([]AuthIdentityMigrationReport, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + limit := query.Limit + if limit <= 0 { + limit = 100 + } + rows, err := exec.QueryContext(ctx, ` +SELECT id, report_type, report_key, details, created_at +FROM auth_identity_migration_reports +WHERE ($1 = '' OR report_type = $1) +ORDER BY created_at DESC, id DESC +LIMIT $2 OFFSET $3`, + strings.TrimSpace(query.ReportType), + limit, + query.Offset, + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + reports := make([]AuthIdentityMigrationReport, 0) + for rows.Next() { + report, scanErr := scanAuthIdentityMigrationReport(rows) + if scanErr != nil { + return nil, scanErr + } + reports = append(reports, report) + } + if err := rows.Err(); err != nil { + return nil, err + } + return reports, nil +} + +func (r *userRepository) GetAuthIdentityMigrationReport(ctx context.Context, reportType, reportKey string) (*AuthIdentityMigrationReport, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + rows, err := exec.QueryContext(ctx, ` +SELECT id, report_type, report_key, details, created_at +FROM auth_identity_migration_reports +WHERE report_type = $1 AND report_key = $2 +LIMIT 1`, + strings.TrimSpace(reportType), + strings.TrimSpace(reportKey), + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, sql.ErrNoRows + } + report, err := scanAuthIdentityMigrationReport(rows) + if err != nil { + return nil, err + } + return &report, rows.Err() +} + +func (r *userRepository) SummarizeAuthIdentityMigrationReports(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + rows, err := exec.QueryContext(ctx, ` +SELECT report_type, COUNT(*) +FROM auth_identity_migration_reports +GROUP BY report_type +ORDER BY report_type ASC`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + summary := &AuthIdentityMigrationReportSummary{ + ByType: make(map[string]int64), + } + for rows.Next() { + var reportType string + var count int64 + if err := rows.Scan(&reportType, &count); err != nil { + return nil, err + } + summary.ByType[reportType] = count + summary.Total += count + } + if err := rows.Err(); err != nil { + return nil, err + } + return summary, nil +} + +func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { + var ( + report AuthIdentityMigrationReport + details []byte + ) + if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil { + return AuthIdentityMigrationReport{}, err + } + report.Details = map[string]any{} + if len(details) > 0 { + if err := json.Unmarshal(details, &report.Details); err != nil { + return AuthIdentityMigrationReport{}, err + } + } + return report, nil +} diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go new file mode 100644 index 00000000..4ecae4a4 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -0,0 +1,544 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + "time" + "unsafe" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +var ( + ErrAuthIdentityOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_OWNERSHIP_CONFLICT", + "auth identity already belongs to another user", + ) + ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", + "auth identity channel already belongs to another user", + ) +) + +type ProviderGrantReason string + +const ( + ProviderGrantReasonSignup ProviderGrantReason = "signup" + ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind" +) + +type AuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type AuthIdentityChannelKey struct { + ProviderType string + ProviderKey string + Channel string + ChannelAppID string + ChannelSubject string +} + +type CreateAuthIdentityInput struct { + UserID int64 + Canonical AuthIdentityKey + Channel *AuthIdentityChannelKey + Issuer *string + VerifiedAt *time.Time + Metadata map[string]any + ChannelMetadata map[string]any +} + +type BindAuthIdentityInput = CreateAuthIdentityInput + +type CreateAuthIdentityResult struct { + Identity *dbent.AuthIdentity + Channel *dbent.AuthIdentityChannel +} + +func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey { + if r == nil || r.Identity == nil { + return AuthIdentityKey{} + } + return AuthIdentityKey{ + ProviderType: r.Identity.ProviderType, + ProviderKey: r.Identity.ProviderKey, + ProviderSubject: r.Identity.ProviderSubject, + } +} + +func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey { + if r == nil || r.Channel == nil { + return nil + } + return &AuthIdentityChannelKey{ + ProviderType: r.Channel.ProviderType, + ProviderKey: r.Channel.ProviderKey, + Channel: r.Channel.Channel, + ChannelAppID: r.Channel.ChannelAppID, + ChannelSubject: r.Channel.ChannelSubject, + } +} + +type UserAuthIdentityLookup struct { + User *dbent.User + Identity *dbent.AuthIdentity + Channel *dbent.AuthIdentityChannel +} + +type ProviderGrantRecordInput struct { + UserID int64 + ProviderType string + GrantReason ProviderGrantReason +} + +type IdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type sqlQueryExecutor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { + if dbent.TxFromContext(ctx) != nil { + return fn(ctx) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx); err != nil { + return err + } + return tx.Commit() +} + +func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) { + client := clientFromContext(ctx, r.client) + + create := client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt) + + identity, err := create.Save(ctx) + if err != nil { + return nil, err + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). + SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(ctx) + if err != nil { + return nil, err + } + } + + return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil +} + +func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) { + identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)), + ). + WithUser(). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: identity.Edges.User, + Identity: identity, + }, nil +} + +func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) { + channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: channel.Edges.Identity.Edges.User, + Identity: channel.Edges.Identity, + Channel: channel, + }, nil +} + +func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) { + var result *CreateAuthIdentityResult + err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + canonical := input.Canonical + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)), + ). + Only(txCtx) + if err != nil && !dbent.IsNotFound(err) { + return err + } + if identity != nil && identity.UserID != input.UserID { + return ErrAuthIdentityOwnershipConflict + } + if identity == nil { + identity, err = client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt). + Save(txCtx) + if err != nil { + return err + } + } else { + update := client.AuthIdentity.UpdateOneID(identity.ID) + if input.Metadata != nil { + update = update.SetMetadata(copyMetadata(input.Metadata)) + } + if input.Issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*input.Issuer)) + } + if input.VerifiedAt != nil { + update = update.SetVerifiedAt(*input.VerifiedAt) + } + identity, err = update.Save(txCtx) + if err != nil { + return err + } + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channel, err = client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)), + authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)), + ). + WithIdentity(). + Only(txCtx) + if err != nil && !dbent.IsNotFound(err) { + return err + } + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID { + return ErrAuthIdentityChannelOwnershipConflict + } + if channel == nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). + SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(txCtx) + if err != nil { + return err + } + } else { + update := client.AuthIdentityChannel.UpdateOneID(channel.ID). + SetIdentityID(identity.ID) + if input.ChannelMetadata != nil { + update = update.SetMetadata(copyMetadata(input.ChannelMetadata)) + } + channel, err = update.Save(txCtx) + if err != nil { + return err + } + } + } + + result = &CreateAuthIdentityResult{Identity: identity, Channel: channel} + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return false, fmt.Errorf("sql executor is not configured") + } + + result, err := exec.ExecContext(ctx, ` +INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason) +VALUES ($1, $2, $3) +ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, + input.UserID, + strings.TrimSpace(input.ProviderType), + string(input.GrantReason), + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + client := clientFromContext(ctx, r.client) + current, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + now := time.Now().UTC() + if current == nil { + create := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(now) + if input.IdentityID != nil { + create = create.SetIdentityID(*input.IdentityID) + } + return create.Save(ctx) + } + + update := client.IdentityAdoptionDecision.UpdateOneID(current.ID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar) + if input.IdentityID != nil { + update = update.SetIdentityID(*input.IdentityID) + } + return update.Save(ctx) +} + +func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) { + return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)). + Only(ctx) +} + +func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastLoginAt(loginAt). + Save(ctx) + return err +} + +func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastActiveAt(activeAt). + Save(ctx) + return err +} + +func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + rows, err := exec.QueryContext(ctx, ` +SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 +FROM user_avatars +WHERE user_id = $1`, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, rows.Err() + } + + var avatar service.UserAvatar + if err := rows.Scan( + &avatar.StorageProvider, + &avatar.StorageKey, + &avatar.URL, + &avatar.ContentType, + &avatar.ByteSize, + &avatar.SHA256, + ); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return &avatar, nil +} + +func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + _, err = exec.ExecContext(ctx, ` +INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) +ON CONFLICT (user_id) DO UPDATE SET + storage_provider = EXCLUDED.storage_provider, + storage_key = EXCLUDED.storage_key, + url = EXCLUDED.url, + content_type = EXCLUDED.content_type, + byte_size = EXCLUDED.byte_size, + sha256 = EXCLUDED.sha256, + updated_at = NOW()`, + userID, + strings.TrimSpace(input.StorageProvider), + strings.TrimSpace(input.StorageKey), + strings.TrimSpace(input.URL), + strings.TrimSpace(input.ContentType), + input.ByteSize, + strings.TrimSpace(input.SHA256), + ) + if err != nil { + return nil, err + } + + return &service.UserAvatar{ + StorageProvider: strings.TrimSpace(input.StorageProvider), + StorageKey: strings.TrimSpace(input.StorageKey), + URL: strings.TrimSpace(input.URL), + ContentType: strings.TrimSpace(input.ContentType), + ByteSize: input.ByteSize, + SHA256: strings.TrimSpace(input.SHA256), + }, nil +} + +func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return err + } + _, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID) + return err +} + +func (r *userRepository) attachUserAvatar(ctx context.Context, user *service.User) error { + if user == nil { + return nil + } + + avatar, err := r.GetUserAvatar(ctx, user.ID) + if err != nil { + return err + } + if avatar == nil { + return nil + } + + user.AvatarURL = avatar.URL + user.AvatarSource = avatar.StorageProvider + user.AvatarMIME = avatar.ContentType + user.AvatarByteSize = avatar.ByteSize + user.AvatarSHA256 = avatar.SHA256 + return nil +} + +func copyMetadata(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor { + if tx := dbent.TxFromContext(ctx); tx != nil { + if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil { + return exec + } + } + if fallback != nil { + return fallback + } + return sqlExecutorFromEntClient(client) +} + +func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + return exec, nil +} + +func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor { + if client == nil { + return nil + } + + clientValue := reflect.ValueOf(client).Elem() + configValue := clientValue.FieldByName("config") + driverValue := configValue.FieldByName("driver") + if !driverValue.IsValid() { + return nil + } + + driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface() + exec, ok := driver.(sqlQueryExecutor) + if !ok { + return nil + } + return exec +} diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go new file mode 100644 index 00000000..19022ec1 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go @@ -0,0 +1,428 @@ +//go:build integration + +package repository + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/suite" +) + +type UserProfileIdentityRepoSuite struct { + suite.Suite + ctx context.Context + client *dbent.Client + repo *userRepository +} + +func TestUserProfileIdentityRepoSuite(t *testing.T) { + suite.Run(t, new(UserProfileIdentityRepoSuite)) +} + +func (s *UserProfileIdentityRepoSuite) SetupTest() { + s.ctx = context.Background() + s.client = testEntClient(s.T()) + s.repo = newUserRepositoryWithSQL(s.client, integrationDB) + + _, err := integrationDB.ExecContext(s.ctx, ` +TRUNCATE TABLE + identity_adoption_decisions, + auth_identity_channels, + auth_identities, + pending_auth_sessions, + auth_identity_migration_reports, + user_provider_default_grants, + user_avatars +RESTART IDENTITY`) + s.Require().NoError(err) +} + +func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User { + s.T().Helper() + + user, err := s.client.User.Create(). + SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())). + SetPasswordHash("test-password-hash"). + SetRole("user"). + SetStatus("active"). + Save(s.ctx) + s.Require().NoError(err) + return user +} + +func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession { + s.T().Helper() + + session, err := s.client.PendingAuthSession.Create(). + SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())). + SetIntent("bind_current_user"). + SetProviderType(key.ProviderType). + SetProviderKey(key.ProviderKey). + SetProviderSubject(key.ProviderSubject). + SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)). + SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}). + SetLocalFlowState(map[string]any{"step": "pending"}). + Save(s.ctx) + s.Require().NoError(err) + return session +} + +func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() { + user := s.mustCreateUser("canonical-channel") + + verifiedAt := time.Now().UTC().Truncate(time.Second) + created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + Channel: "mp", + ChannelAppID: "wx-app", + ChannelSubject: "openid-123", + }, + Issuer: stringPtr("https://issuer.example"), + VerifiedAt: &verifiedAt, + Metadata: map[string]any{"unionid": "union-123"}, + ChannelMetadata: map[string]any{"openid": "openid-123"}, + }) + s.Require().NoError(err) + s.Require().NotNil(created.Identity) + s.Require().NotNil(created.Channel) + + canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, canonical.User.ID) + s.Require().Equal(created.Identity.ID, canonical.Identity.ID) + s.Require().Equal("union-123", canonical.Identity.ProviderSubject) + + channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, channel.User.ID) + s.Require().Equal(created.Identity.ID, channel.Identity.ID) + s.Require().Equal(created.Channel.ID, channel.Channel.ID) +} + +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() { + owner := s.mustCreateUser("owner") + other := s.mustCreateUser("other") + + first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "first"}, + ChannelMetadata: map[string]any{"scope": "read"}, + }) + s.Require().NoError(err) + + second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "second"}, + ChannelMetadata: map[string]any{"scope": "write"}, + }) + s.Require().NoError(err) + s.Require().Equal(first.Identity.ID, second.Identity.ID) + s.Require().Equal(first.Channel.ID, second.Channel.ID) + s.Require().Equal("second", second.Identity.Metadata["username"]) + s.Require().Equal("write", second.Channel.Metadata["scope"]) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-2", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict) +} + +func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() { + user := s.mustCreateUser("tx-rollback") + expectedErr := errors.New("rollback") + + err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error { + _, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }, + }) + s.Require().NoError(err) + + inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "oidc", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + return expectedErr + }) + s.Require().ErrorIs(err, expectedErr) + + _, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }) + s.Require().True(dbent.IsNotFound(err)) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`, + user.ID, + "oidc", + string(ProviderGrantReasonFirstBind), + ).Scan(&count)) + s.Require().Zero(count) +} + +func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() { + user := s.mustCreateUser("grant") + + inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().False(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonSignup, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2`, + user.ID, + "wechat", + ).Scan(&count)) + s.Require().Equal(2, count) +} + +func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() { + user := s.mustCreateUser("adoption") + identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + s.Require().NoError(err) + + session := s.mustCreatePendingAuthSession(identity.IdentityRef()) + + first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + s.Require().NoError(err) + s.Require().True(first.AdoptDisplayName) + s.Require().False(first.AdoptAvatar) + s.Require().Nil(first.IdentityID) + + second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + s.Require().NoError(err) + s.Require().Equal(first.ID, second.ID) + s.Require().NotNil(second.IdentityID) + s.Require().Equal(identity.Identity.ID, *second.IdentityID) + s.Require().True(second.AdoptAvatar) + + loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID) + s.Require().NoError(err) + s.Require().Equal(second.ID, loaded.ID) + s.Require().Equal(identity.Identity.ID, *loaded.IdentityID) +} + +func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() { + user := s.mustCreateUser("avatar") + + inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: "data:image/png;base64,QUJD", + ContentType: "image/png", + ByteSize: 3, + SHA256: "902fbdd2b1df0c4f70b4a5d23525e932", + }) + s.Require().NoError(err) + s.Require().Equal("inline", inlineAvatar.StorageProvider) + s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL) + + loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("image/png", loadedAvatar.ContentType) + s.Require().Equal(3, loadedAvatar.ByteSize) + + _, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/avatar.png", + }) + s.Require().NoError(err) + + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("remote_url", loadedAvatar.StorageProvider) + s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL) + s.Require().Zero(loadedAvatar.ByteSize) + + s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID)) + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Nil(loadedAvatar) +} + +func (s *UserProfileIdentityRepoSuite) TestAuthIdentityMigrationReportHelpers_ListAndSummarize() { + _, err := integrationDB.ExecContext(s.ctx, ` +INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at) +VALUES + ('wechat_openid_only_requires_remediation', 'u-1', '{"user_id":1}'::jsonb, '2026-04-20T10:00:00Z'), + ('wechat_openid_only_requires_remediation', 'u-2', '{"user_id":2}'::jsonb, '2026-04-20T11:00:00Z'), + ('oidc_synthetic_email_requires_manual_recovery', 'u-3', '{"user_id":3}'::jsonb, '2026-04-20T12:00:00Z')`) + s.Require().NoError(err) + + summary, err := s.repo.SummarizeAuthIdentityMigrationReports(s.ctx) + s.Require().NoError(err) + s.Require().Equal(int64(3), summary.Total) + s.Require().Equal(int64(2), summary.ByType["wechat_openid_only_requires_remediation"]) + s.Require().Equal(int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"]) + + reports, err := s.repo.ListAuthIdentityMigrationReports(s.ctx, AuthIdentityMigrationReportQuery{ + ReportType: "wechat_openid_only_requires_remediation", + Limit: 10, + }) + s.Require().NoError(err) + s.Require().Len(reports, 2) + s.Require().Equal("u-2", reports[0].ReportKey) + s.Require().Equal(float64(2), reports[0].Details["user_id"]) + + report, err := s.repo.GetAuthIdentityMigrationReport(s.ctx, "oidc_synthetic_email_requires_manual_recovery", "u-3") + s.Require().NoError(err) + s.Require().Equal("u-3", report.ReportKey) + s.Require().Equal(float64(3), report.Details["user_id"]) +} + +func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() { + user := s.mustCreateUser("activity") + loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC) + activeAt := loginAt.Add(5 * time.Minute) + + s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt)) + s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt)) + + var storedLoginAt sqlNullTime + var storedActiveAt sqlNullTime + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT last_login_at, last_active_at +FROM users +WHERE id = $1`, + user.ID, + ).Scan(&storedLoginAt, &storedActiveAt)) + s.Require().True(storedLoginAt.Valid) + s.Require().True(storedActiveAt.Valid) + s.Require().True(storedLoginAt.Time.Equal(loginAt)) + s.Require().True(storedActiveAt.Time.Equal(activeAt)) +} + +type sqlNullTime struct { + Time time.Time + Valid bool +} + +func (t *sqlNullTime) Scan(value any) error { + switch v := value.(type) { + case time.Time: + t.Time = v + t.Valid = true + return nil + case nil: + t.Time = time.Time{} + t.Valid = false + return nil + default: + return fmt.Errorf("unsupported scan type %T", value) + } +} + +func stringPtr(v string) *string { + return &v +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 913e1c40..0c607ecc 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -64,6 +64,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). + SetNillableLastLoginAt(userIn.LastLoginAt). + SetNillableLastActiveAt(userIn.LastActiveAt). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -151,6 +154,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). SetTotalRecharged(userIn.TotalRecharged) + if userIn.SignupSource != "" { + updateOp = updateOp.SetSignupSource(userIn.SignupSource) + } + if userIn.LastLoginAt != nil { + updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt) + } + if userIn.LastActiveAt != nil { + updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt) + } if userIn.BalanceNotifyThreshold == nil { updateOp = updateOp.ClearBalanceNotifyThreshold() } @@ -300,6 +312,7 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) var field string defaultField := true + nullsLastField := false switch sortBy { case "email": field = dbuser.FieldEmail @@ -322,6 +335,14 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) case "created_at": field = dbuser.FieldCreatedAt defaultField = false + case "last_login_at": + field = dbuser.FieldLastLoginAt + defaultField = false + nullsLastField = true + case "last_active_at": + field = dbuser.FieldLastActiveAt + defaultField = false + nullsLastField = true default: field = dbuser.FieldID } @@ -330,11 +351,23 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(), + dbent.Asc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)} } if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(), + dbent.Desc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)} } @@ -558,10 +591,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { return } dst.ID = src.ID + dst.SignupSource = src.SignupSource + dst.LastLoginAt = src.LastLoginAt + dst.LastActiveAt = src.LastActiveAt dst.CreatedAt = src.CreatedAt dst.UpdatedAt = src.UpdatedAt } +func userSignupSourceOrDefault(signupSource string) string { + signupSource = strings.TrimSpace(signupSource) + if signupSource == "" { + return "email" + } + return signupSource +} + // marshalExtraEmails serializes notify email entries to JSON for storage. func marshalExtraEmails(entries []service.NotifyEmailEntry) string { return service.MarshalNotifyEmails(entries) diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go index ab84b0e9..8abef45a 100644 --- a/backend/internal/repository/user_repo_sort_integration_test.go +++ b/backend/internal/repository/user_repo_sort_integration_test.go @@ -4,6 +4,7 @@ package repository import ( "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" @@ -36,4 +37,86 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() { s.Require().Equal(first.ID, users[1].ID) } +func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() { + lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond) + + created := s.mustCreateUser(&service.User{ + Email: "identity-meta@example.com", + SignupSource: "github", + LastLoginAt: &lastLoginAt, + LastActiveAt: &lastActiveAt, + }) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("github", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() { + created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"}) + lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond) + + created.SignupSource = "oidc" + created.LastLoginAt = &lastLoginAt + created.LastActiveAt = &lastActiveAt + + s.Require().NoError(s.repo.Update(s.ctx, created)) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("oidc", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastLoginAtDesc() { + older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Microsecond) + newer := time.Now().Add(-1 * time.Hour).UTC().Truncate(time.Microsecond) + + s.mustCreateUser(&service.User{Email: "nil-login@example.com"}) + s.mustCreateUser(&service.User{Email: "older-login@example.com", LastLoginAt: &older}) + s.mustCreateUser(&service.User{Email: "newer-login@example.com", LastLoginAt: &newer}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_login_at", + SortOrder: "desc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal("newer-login@example.com", users[0].Email) + s.Require().Equal("older-login@example.com", users[1].Email) + s.Require().Equal("nil-login@example.com", users[2].Email) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() { + earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond) + later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond) + + s.mustCreateUser(&service.User{Email: "nil-active@example.com"}) + s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later}) + s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_active_at", + SortOrder: "asc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal("earlier-active@example.com", users[0].Email) + s.Require().Equal("later-active@example.com", users[1].Email) + s.Require().Equal("nil-active@example.com", users[2].Email) +} + func TestUserRepoSortSuiteSmoke(_ *testing.T) {} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index b686b986..e903898f 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -479,7 +479,7 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyOIDCConnectRedirectURL: "", service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", - service.SettingKeyOIDCConnectUsePKCE: "false", + service.SettingKeyOIDCConnectUsePKCE: "true", service.SettingKeyOIDCConnectValidateIDToken: "true", service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", service.SettingKeyOIDCConnectClockSkewSeconds: "120", @@ -549,7 +549,7 @@ func TestAPIContracts(t *testing.T) { "oidc_connect_redirect_url": "", "oidc_connect_frontend_redirect_url": "/auth/oidc/callback", "oidc_connect_token_auth_method": "client_secret_post", - "oidc_connect_use_pkce": false, + "oidc_connect_use_pkce": true, "oidc_connect_validate_id_token": true, "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256", "oidc_connect_clock_skew_seconds": 120, diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index c143b030..911a4064 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -64,12 +64,26 @@ func RegisterAuthRoutes( }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) + auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart) + auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback) + auth.POST("/oauth/pending/exchange", + rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.ExchangePendingOAuthCompletion, + ) auth.POST("/oauth/linuxdo/complete-registration", rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteLinuxDoOAuthRegistration, ) + auth.POST("/oauth/wechat/complete-registration", + rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteWeChatOAuthRegistration, + ) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.POST("/oauth/oidc/complete-registration", diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 419ddbc3..b802a9c2 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro } func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index fbc856cf..323286b0 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error { return s.deleteErr } +func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + panic("unexpected GetUserAvatar call") +} + +func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error { + panic("unexpected DeleteUserAvatar call") +} + func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected List call") } diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go new file mode 100644 index 00000000..b7e86e12 --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service.go @@ -0,0 +1,326 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +var ( + ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found") + ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired") + ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used") + ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid") + ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired") + ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used") + ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session") +) + +const ( + defaultPendingAuthTTL = 15 * time.Minute + defaultPendingAuthCompletionTTL = 5 * time.Minute +) + +type PendingAuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type CreatePendingAuthSessionInput struct { + SessionToken string + Intent string + Identity PendingAuthIdentityKey + TargetUserID *int64 + RedirectTo string + ResolvedEmail string + RegistrationPasswordHash string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + LocalFlowState map[string]any + ExpiresAt time.Time +} + +type IssuePendingAuthCompletionCodeInput struct { + PendingAuthSessionID int64 + BrowserSessionKey string + TTL time.Duration +} + +type IssuePendingAuthCompletionCodeResult struct { + Code string + ExpiresAt time.Time +} + +type PendingIdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type AuthPendingIdentityService struct { + entClient *dbent.Client +} + +func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService { + return &AuthPendingIdentityService{entClient: entClient} +} + +func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken := strings.TrimSpace(input.SessionToken) + if sessionToken == "" { + var err error + sessionToken, err = randomOpaqueToken(24) + if err != nil { + return nil, err + } + } + + expiresAt := input.ExpiresAt.UTC() + if expiresAt.IsZero() { + expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL) + } + + create := s.entClient.PendingAuthSession.Create(). + SetSessionToken(sessionToken). + SetIntent(strings.TrimSpace(input.Intent)). + SetProviderType(strings.TrimSpace(input.Identity.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)). + SetRedirectTo(strings.TrimSpace(input.RedirectTo)). + SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)). + SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)). + SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)). + SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)). + SetLocalFlowState(copyPendingMap(input.LocalFlowState)). + SetExpiresAt(expiresAt) + if input.TargetUserID != nil { + create = create.SetTargetUserID(*input.TargetUserID) + } + return create.Save(ctx) +} + +func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + + code, err := randomOpaqueToken(24) + if err != nil { + return nil, err + } + ttl := input.TTL + if ttl <= 0 { + ttl = defaultPendingAuthCompletionTTL + } + expiresAt := time.Now().UTC().Add(ttl) + + update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeHash(hashPendingAuthCode(code)). + SetCompletionCodeExpiresAt(expiresAt) + if strings.TrimSpace(input.BrowserSessionKey) != "" { + update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)) + } + if _, err := update.Save(ctx); err != nil { + return nil, err + } + + return &IssuePendingAuthCompletionCodeResult{ + Code: code, + ExpiresAt: expiresAt, + }, nil +} + +func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode)) + session, err := s.entClient.PendingAuthSession.Query(). + Where(pendingauthsession.CompletionCodeHashEQ(codeHash)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthCodeInvalid + } + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed) +} + +func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) +} + +func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil { + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken = strings.TrimSpace(sessionToken) + if sessionToken == "" { + return nil, ErrPendingAuthSessionNotFound + } + + session, err := s.entClient.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) consumeSession( + ctx context.Context, + session *dbent.PendingAuthSession, + browserSessionKey string, + expiredErr error, + consumedErr error, +) (*dbent.PendingAuthSession, error) { + if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil { + return nil, err + } + + now := time.Now().UTC() + updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + SetConsumedAt(now). + SetCompletionCodeHash(""). + ClearCompletionCodeExpiresAt(). + Save(ctx) + if err != nil { + return nil, err + } + return updated, nil +} + +func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { + if session == nil { + return ErrPendingAuthSessionNotFound + } + + now := time.Now().UTC() + if session.ConsumedAt != nil { + return consumedErr + } + if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { + return expiredErr + } + if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) { + return expiredErr + } + if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) { + return ErrPendingAuthBrowserMismatch + } + return nil +} + +func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + existing, err := s.entClient.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if existing == nil { + create := s.entClient.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil { + create = create.SetIdentityID(*input.IdentityID) + } + return create.Save(ctx) + } + + update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar) + if input.IdentityID != nil { + update = update.SetIdentityID(*input.IdentityID) + } + return update.Save(ctx) +} + +func copyPendingMap(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func randomOpaqueToken(byteLen int) (string, error) { + if byteLen <= 0 { + byteLen = 16 + } + buf := make([]byte, byteLen) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func hashPendingAuthCode(code string) string { + sum := sha256.Sum256([]byte(code)) + return hex.EncodeToString(sum[:]) +} diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go new file mode 100644 index 00000000..c69ebfd2 --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service_test.go @@ -0,0 +1,224 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return NewAuthPendingIdentityService(client), client +} + +func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + targetUser, err := client.User.Create(). + SetEmail("pending-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + TargetUserID: &targetUser.ID, + RedirectTo: "/profile", + ResolvedEmail: "user@example.com", + BrowserSessionKey: "browser-1", + UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"}, + LocalFlowState: map[string]any{"step": "email_required"}, + }) + require.NoError(t, err) + require.NotEmpty(t, session.SessionToken) + require.Equal(t, "bind_current_user", session.Intent) + require.Equal(t, "wechat", session.ProviderType) + require.NotNil(t, session.TargetUserID) + require.Equal(t, targetUser.ID, *session.TargetUserID) + require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"]) + require.Equal(t, "email_required", session.LocalFlowState["step"]) +} + +func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expected", + UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"}, + LocalFlowState: map[string]any{"step": "pending"}, + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expected", + }) + require.NoError(t, err) + require.NotEmpty(t, issued.Code) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + require.Empty(t, consumed.CompletionCodeHash) + require.Nil(t, consumed.CompletionCodeExpiresAt) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.ErrorIs(t, err, ErrPendingAuthCodeInvalid) +} + +func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expired", + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expired", + TTL: time.Second, + }) + require.NoError(t, err) + + _, err = client.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired") + require.ErrorIs(t, err, ErrPendingAuthCodeExpired) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("adoption@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-open"). + SetProviderSubject("union-adoption"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + require.NoError(t, err) + + first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + require.NoError(t, err) + require.True(t, first.AdoptDisplayName) + require.False(t, first.AdoptAvatar) + require.Nil(t, first.IdentityID) + + second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + require.NoError(t, err) + require.Equal(t, first.ID, second.ID) + require.NotNil(t, second.IdentityID) + require.Equal(t, identity.ID, *second.IdentityID) + require.True(t, second.AdoptAvatar) +} + +func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "subject-session-token", + }, + BrowserSessionKey: "browser-session", + LocalFlowState: map[string]any{ + "completion_response": map[string]any{ + "access_token": "token", + }, + }, + }) + require.NoError(t, err) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index fd28cd42..962009ce 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -13,6 +13,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -106,6 +107,13 @@ func NewAuthService( } } +func (s *AuthService) EntClient() *dbent.Client { + if s == nil { + return nil + } + return s.entClient +} + // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { return s.RegisterWithVerification(ctx, email, password, "", "", "") @@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } + s.postAuthUserBootstrap(ctx, user, "email", true) s.assignDefaultSubscriptions(ctx, user.ID) // 标记邀请码为已使用(如果使用了邀请码) @@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string if !user.IsActive() { return "", nil, ErrUserNotActive } + s.touchUserLogin(ctx, user.ID) // 生成JWT token token, err := s.GenerateToken(user) @@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser + s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.assignDefaultSubscriptions(ctx, user.ID) } } else { @@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } + s.touchUserLogin(ctx, user.ID) token, err := s.GenerateToken(user) if err != nil { @@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrServiceUnavailable } user = newUser + s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.assignDefaultSubscriptions(ctx, user.ID) } } else { @@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser + s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.assignDefaultSubscriptions(ctx, user.ID) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { @@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } + s.touchUserLogin(ctx, user.ID) tokenPair, err := s.GenerateTokenPair(ctx, user, "") if err != nil { @@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } -// pendingOAuthTokenTTL is the validity period for pending OAuth tokens. -const pendingOAuthTokenTTL = 10 * time.Minute - -// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens. -const pendingOAuthPurpose = "pending_oauth_registration" - -type pendingOAuthClaims struct { - Email string `json:"email"` - Username string `json:"username"` - Purpose string `json:"purpose"` - jwt.RegisteredClaims -} - -// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity -// while waiting for the user to supply an invitation code. -func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) { - now := time.Now() - claims := &pendingOAuthClaims{ - Email: email, - Username: username, - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(s.cfg.JWT.Secret)) -} - -// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity. -// Returns ErrInvalidToken when the token is invalid or expired. -func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) { - if len(tokenStr) > maxTokenLength { - return "", "", ErrInvalidToken - } - parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) - token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - return []byte(s.cfg.JWT.Secret), nil - }) - if parseErr != nil { - return "", "", ErrInvalidToken - } - claims, ok := token.Claims.(*pendingOAuthClaims) - if !ok || !token.Valid { - return "", "", ErrInvalidToken - } - if claims.Purpose != pendingOAuthPurpose { - return "", "", ErrInvalidToken - } - return claims.Email, claims.Username, nil -} - func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return @@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int } } +func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { + if user == nil || user.ID <= 0 { + return + } + + if strings.TrimSpace(signupSource) == "" { + signupSource = "email" + } + s.updateUserSignupSource(ctx, user.ID, signupSource) + + if signupSource == "email" { + s.ensureEmailAuthIdentity(ctx, user) + } + if touchLogin { + s.touchUserLogin(ctx, user.ID) + } +} + +func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) { + if s == nil || s.entClient == nil || userID <= 0 { + return + } + if strings.TrimSpace(signupSource) == "" { + return + } + if err := s.entClient.User.UpdateOneID(userID). + SetSignupSource(signupSource). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err) + } +} + +func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) { + if s == nil || s.entClient == nil || userID <= 0 { + return + } + now := time.Now().UTC() + if err := s.entClient.User.UpdateOneID(userID). + SetLastLoginAt(now). + SetLastActiveAt(now). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err) + } +} + +func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) { + if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { + return + } + + email := strings.ToLower(strings.TrimSpace(user.Email)) + if email == "" || isReservedEmail(email) { + return + } + + if err := s.entClient.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(email). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{ + "source": "auth_service_dual_write", + }). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + } +} + +func inferLegacySignupSource(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + switch { + case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain): + return "linuxdo" + case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain): + return "oidc" + case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain): + return "wechat" + default: + return "email" + } +} + func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error { if s.settingService == nil { return nil @@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) { func isReservedEmail(email string) bool { normalized := strings.ToLower(strings.TrimSpace(email)) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) || - strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) + strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) } // GenerateToken 生成JWT access token diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go new file mode 100644 index 00000000..5bd2b25d --- /dev/null +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -0,0 +1,153 @@ +//go:build unit + +package service_test + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +type authIdentitySettingRepoStub struct { + values map[string]string +} + +func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} + +func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + repo := repository.NewUserRepository(client, db) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-auth-identity-secret", + ExpireHour: 1, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, + }, cfg) + + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil) + return svc, repo, client +} + +func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) { + svc, _, client := newAuthServiceWithEnt(t) + ctx := context.Background() + + token, user, err := svc.Register(ctx, "user@example.com", "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, user) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "email", storedUser.SignupSource) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("user@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.NotNil(t, identity.VerifiedAt) +} + +func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { + svc, repo, client := newAuthServiceWithEnt(t) + ctx := context.Background() + + user := &service.User{ + Email: "login@example.com", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 1, + Concurrency: 1, + } + require.NoError(t, user.SetPassword("password")) + require.NoError(t, repo.Create(ctx, user)) + + old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second) + _, err := client.User.UpdateOneID(user.ID). + SetLastLoginAt(old). + SetLastActiveAt(old). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + require.True(t, storedUser.LastLoginAt.After(old)) + require.True(t, storedUser.LastActiveAt.After(old)) +} diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go deleted file mode 100644 index 0472e06c..00000000 --- a/backend/internal/service/auth_service_pending_oauth_test.go +++ /dev/null @@ -1,146 +0,0 @@ -//go:build unit - -package service - -import ( - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/require" -) - -func newAuthServiceForPendingOAuthTest() *AuthService { - cfg := &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret-pending-oauth", - ExpireHour: 1, - }, - } - return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) -} - -// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。 -func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - token, err := svc.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - require.NotEmpty(t, token) - - email, username, err := svc.VerifyPendingOAuthToken(token) - require.NoError(t, err) - require.Equal(t, "user@example.com", email) - require.Equal(t, "alice", username) -} - -// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - // 签发一个普通 access token(JWTClaims,无 Purpose 字段) - accessToken, err := svc.GenerateToken(&User{ - ID: 1, - Email: "user@example.com", - Role: RoleUser, - }) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(accessToken) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "some_other_purpose", - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "", // 旧 token 无此字段,反序列化后为零值 - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - past := time.Now().Add(-1 * time.Hour) - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(past), - IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) { - other := NewAuthService(nil, nil, nil, nil, &config.Config{ - JWT: config.JWTConfig{Secret: "other-secret"}, - }, nil, nil, nil, nil, nil, nil) - - token, err := other.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - - svc := newAuthServiceForPendingOAuthTest() - _, _, err = svc.VerifyPendingOAuthToken(token) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_TooLong(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - giant := make([]byte, maxTokenLength+1) - for i := range giant { - giant[i] = 'a' - } - _, _, err := svc.VerifyPendingOAuthToken(string(giant)) - require.ErrorIs(t, err, ErrInvalidToken) -} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index cb452efb..1dddf77e 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。 const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid" +// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。 +const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid" + // Setting keys const ( // 注册设置 @@ -153,6 +156,29 @@ const ( SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + // 第三方认证来源默认授予配置 + SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" + SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency" + SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions" + SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup" + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind" + SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance" + SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency" + SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions" + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup" + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind" + SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance" + SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency" + SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions" + SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup" + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind" + SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance" + SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency" + SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions" + SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup" + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind" + SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup" + // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 6c09e354..09e60220 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -13,14 +13,30 @@ import ( "sync" "sync/atomic" "time" + + "golang.org/x/sync/singleflight" ) const ( openAIAccountScheduleLayerPreviousResponse = "previous_response_id" openAIAccountScheduleLayerSessionSticky = "session_hash" openAIAccountScheduleLayerLoadBalance = "load_balance" + openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled" ) +const ( + openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second + openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second +) + +type cachedOpenAIAdvancedSchedulerSetting struct { + enabled bool + expiresAt int64 +} + +var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting +var openAIAdvancedSchedulerSettingSF singleflight.Group + type OpenAIAccountScheduleRequest struct { GroupID *int64 SessionHash string @@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler return snapshot } -func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { +func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository { + if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil { + return nil + } + return s.rateLimitService.settingService.settingRepo +} + +func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool { + if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled + } + } + + result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) { + if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled, nil + } + } + + enabled := false + if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil { + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout) + defer cancel() + + value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey) + if err == nil { + enabled = strings.EqualFold(strings.TrimSpace(value), "true") + } + } + + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: enabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) + return enabled, nil + }) + + enabled, _ := result.(bool) + return enabled +} + +func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler { if s == nil { return nil } + if !s.isOpenAIAdvancedSchedulerEnabled(ctx) { + return nil + } s.openaiSchedulerOnce.Do(func() { if s.openaiAccountStats == nil { s.openaiAccountStats = newOpenAIAccountRuntimeStats() @@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule return s.openaiScheduler } +func resetOpenAIAdvancedSchedulerSettingCacheForTest() { + openAIAdvancedSchedulerSettingCache = atomic.Value{} + openAIAdvancedSchedulerSettingSF = singleflight.Group{} +} + func (s *OpenAIGatewayService) SelectAccountWithScheduler( ctx context.Context, groupID *int64, @@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( requiredTransport OpenAIUpstreamTransport, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { decision := OpenAIAccountScheduleDecision{} - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(ctx) if scheduler == nil { selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) decision.Layer = openAIAccountScheduleLayerLoadBalance @@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( } func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64 } func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { } func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return OpenAIAccountSchedulerMetricsSnapshot{} } diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 088815ed..a54f2614 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "math" "sync" @@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct { accountsByID map[int64]*Account } +type schedulerTestOpenAIAccountRepo struct { + AccountRepository + accounts []Account +} + +func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) { + for i := range r.accounts { + if r.accounts[i].ID == id { + return &r.accounts[i], nil + } + } + return nil, errors.New("account not found") +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} + +type schedulerTestConcurrencyCache struct { + ConcurrencyCache + loadBatchErr error + loadMap map[int64]*AccountLoadInfo + acquireResults map[int64]bool + waitCounts map[int64]int + skipDefaultLoad bool +} + +func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if c.acquireResults != nil { + if result, ok := c.acquireResults[accountID]; ok { + return result, nil + } + } + return true, nil +} + +func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + return nil +} + +func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if c.loadBatchErr != nil { + return nil, c.loadBatchErr + } + out := make(map[int64]*AccountLoadInfo, len(accounts)) + if c.skipDefaultLoad && c.loadMap != nil { + for _, acc := range accounts { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + } + } + return out, nil + } + for _, acc := range accounts { + if c.loadMap != nil { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + continue + } + } + out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} + } + return out, nil +} + +func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if c.waitCounts != nil { + if count, ok := c.waitCounts[accountID]; ok { + return count, nil + } + } + return 0, nil +} + +type schedulerTestGatewayCache struct { + sessionBindings map[string]int64 + deletedSessions map[string]int +} + +func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + if id, ok := c.sessionBindings[sessionHash]; ok { + return id, nil + } + return 0, errors.New("not found") +} + +func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + if c.sessionBindings == nil { + c.sessionBindings = make(map[string]int64) + } + c.sessionBindings[sessionHash] = accountID + return nil +} + +func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + return nil +} + +func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + if c.sessionBindings == nil { + return nil + } + if c.deletedSessions == nil { + c.deletedSessions = make(map[string]int) + } + c.deletedSessions[sessionHash]++ + delete(c.sessionBindings, sessionHash) + return nil +} + +func newSchedulerTestOpenAIWSV2Config() *config.Config { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + return cfg +} + +type openAIAdvancedSchedulerSettingRepoStub struct { + values map[string]string +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &Setting{Key: key, Value: value}, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if s == nil || s.values == nil { + return "", ErrSettingNotFound + } + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected call to Set") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + panic("unexpected call to GetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected call to SetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected call to GetAll") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error { + panic("unexpected call to Delete") +} + +func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + repo := &openAIAdvancedSchedulerSettingRepoStub{ + values: map[string]string{}, + } + if enabled != "" { + repo.values[openAIAdvancedSchedulerSettingKey] = enabled + } + return &RateLimitService{ + settingService: NewSettingService(repo, &config.Config{}), + } +} + func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { if len(s.snapshotAccounts) == 0 { return nil, false, nil @@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6 return &cloned, nil } +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10106) + accounts := []Account{ + { + ID: 36001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + }, + { + ID: 36002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cache := &schedulerTestGatewayCache{} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour)) + require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_disabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickyPreviousHit) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10107) + accounts := []Account{ + { + ID: 37001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 37002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour)) + require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_enabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(37001), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer) + require.True(t, decision.StickyPreviousHit) +} + +func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + svc := &OpenAIGatewayService{} + ttft := 120 + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) + svc.RecordOpenAIAccountSwitch() + + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) { ctx := context.Background() groupID := int64(10101) @@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}} snapshotService := &SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, + cache: cache, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) require.NoError(t, err) @@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}} snapshotService := &SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + } account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) require.NoError(t, err) @@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} snapshotCache := &openAISnapshotCacheStub{ snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup}, } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, cache: cache, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, } @@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( "openai_apikey_responses_websockets_v2_enabled": true, }, } - cache := &stubGatewayCache{} + cache := &schedulerTestGatewayCache{} cfg := &config.Config{} cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true @@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: cfg, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } store := svc.getOpenAIWSStateStore() @@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_abc": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS Priority: 9, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_sticky_busy": 21001, }, @@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ acquireResults: map[int64]bool{ 21001: false, // sticky 账号已满 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) @@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP "openai_ws_force_http": true, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_force_http": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick }, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_ws_only": 2201, }, } - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, @@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, - cfg: newOpenAIWSV2TestConfig(), - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: newSchedulerTestOpenAIWSV2Config(), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1}, @@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_metrics": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, @@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { } func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + svc := &OpenAIGatewayService{} ttft := 120 svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) svc.RecordOpenAIAccountSwitch() snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() - require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) require.Equal(t, 7, svc.openAIWSLBTopK()) require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) @@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t * require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() scheduler.service = &OpenAIGatewayService{cfg: cfg} account := &Account{ ID: 8801, diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go index c5de8203..ddafc6eb 100644 --- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go +++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go @@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 59764b29..34462a3a 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo SettingHelpImageURL, SettingHelpText, SettingCancelRateLimitOn, SettingCancelRateLimitMax, SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode, + SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource, + SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource, } vals, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { return nil, fmt.Errorf("get payment config settings: %w", err) } cfg := s.parsePaymentConfig(vals) + if s.entClient != nil { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)). + All(ctx) + if err != nil { + return nil, fmt.Errorf("list enabled provider instances: %w", err) + } + cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, buildVisibleMethodSourceAvailability(instances)) + } else { + cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, nil) + } // Load Stripe publishable key from the first enabled Stripe provider instance cfg.StripePublishableKey = s.getStripePublishableKey(ctx) return cfg, nil @@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy } if raw := vals[SettingEnabledPaymentTypes]; raw != "" { + types := make([]string, 0, len(strings.Split(raw, ","))) for _, t := range strings.Split(raw, ",") { t = strings.TrimSpace(t) if t != "" { - cfg.EnabledTypes = append(cfg.EnabledTypes, t) + types = append(types, t) } } + cfg.EnabledTypes = NormalizeVisibleMethods(types) } return cfg } // getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance. func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string { + if s.entClient == nil { + return "" + } instances, err := s.entClient.PaymentProviderInstance.Query(). Where( paymentproviderinstance.EnabledEQ(true), @@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int { } return v } + +func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool { + available := make(map[string]bool, 4) + for _, inst := range instances { + switch inst.ProviderKey { + case payment.TypeAlipay: + if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) { + available[VisibleMethodSourceOfficialAlipay] = true + } + case payment.TypeWxpay: + if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) { + available[VisibleMethodSourceOfficialWechat] = true + } + case payment.TypeEasyPay: + for _, supportedType := range splitTypes(inst.SupportedTypes) { + switch NormalizeVisibleMethod(supportedType) { + case payment.TypeAlipay: + available[VisibleMethodSourceEasyPayAlipay] = true + case payment.TypeWxpay: + available[VisibleMethodSourceEasyPayWechat] = true + } + } + } + } + return available +} + +func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string { + shouldExpose := map[string]bool{ + payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available), + payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available), + } + + seen := make(map[string]struct{}, len(base)+2) + out := make([]string, 0, len(base)+2) + appendType := func(paymentType string) { + paymentType = NormalizeVisibleMethod(paymentType) + if paymentType == "" { + return + } + if _, ok := seen[paymentType]; ok { + return + } + seen[paymentType] = struct{}{} + out = append(out, paymentType) + } + + for _, paymentType := range base { + visibleMethod := NormalizeVisibleMethod(paymentType) + switch visibleMethod { + case payment.TypeAlipay, payment.TypeWxpay: + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + default: + appendType(visibleMethod) + } + } + + for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} { + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + } + return out +} + +func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool { + enabledKey := visibleMethodEnabledSettingKey(method) + sourceKey := visibleMethodSourceSettingKey(method) + if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" { + return false + } + source := NormalizeVisibleMethodSource(method, vals[sourceKey]) + return source != "" && available[source] +} diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go index 027bb796..10919058 100644 --- a/backend/internal/service/payment_config_service_test.go +++ b/backend/internal/service/payment_config_service_test.go @@ -1,9 +1,17 @@ package service import ( + "context" + "database/sql" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/internal/payment" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" ) func TestPcParseFloat(t *testing.T) { @@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) { } }) + t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) { + t.Parallel() + vals := map[string]string{ + SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay", + } + cfg := svc.parsePaymentConfig(vals) + if len(cfg.EnabledTypes) != 2 { + t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes)) + } + if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" { + t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes) + } + }) + t.Run("empty enabled types string", func(t *testing.T) { t.Parallel() vals := map[string]string{ @@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) { }) } } + +func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) { + t.Parallel() + + base := []string{"alipay", "wxpay", "stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay, + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + } + available := map[string]bool{ + VisibleMethodSourceOfficialAlipay: true, + VisibleMethodSourceOfficialWechat: false, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"alipay", "stripe"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) { + t.Parallel() + + base := []string{"stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay, + } + available := map[string]bool{ + VisibleMethodSourceEasyPayAlipay: true, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"stripe", "alipay"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestBuildVisibleMethodSourceAvailability(t *testing.T) { + t.Parallel() + + instances := []*dbent.PaymentProviderInstance{ + {ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"}, + {ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"}, + {ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"}, + } + + got := buildVisibleMethodSourceAvailability(instances) + if !got[VisibleMethodSourceOfficialAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay) + } + if !got[VisibleMethodSourceEasyPayAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay) + } + if !got[VisibleMethodSourceOfficialWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat) + } + if !got[VisibleMethodSourceEasyPayWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat) + } +} + +func TestGetPaymentConfigAppliesVisibleMethodRouting(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create easypay instance: %v", err) + } + + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingEnabledPaymentTypes: "alipay,wxpay,stripe", + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: "easypay", + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: "wxpay", + }, + }, + } + + cfg, err := svc.GetPaymentConfig(ctx) + if err != nil { + t.Fatalf("GetPaymentConfig returned error: %v", err) + } + + want := []string{payment.TypeAlipay, payment.TypeStripe} + if len(cfg.EnabledTypes) != len(want) { + t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes) + } + for i := range want { + if cfg.EnabledTypes[i] != want[i] { + t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes) + } + } +} + +func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:payment_config_service?mode=memory&cache=shared") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + t.Fatalf("enable foreign keys: %v", err) + } + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +type paymentConfigSettingRepoStub struct { + values map[string]string +} + +func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) { + return nil, nil +} +func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + return s.values[key], nil +} +func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil } +func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = s.values[key] + } + return out, nil +} +func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} +func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return s.values, nil +} +func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil } diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go new file mode 100644 index 00000000..894a8198 --- /dev/null +++ b/backend/internal/service/payment_resume_service.go @@ -0,0 +1,248 @@ +package service + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/url" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const ( + PaymentSourceHostedRedirect = "hosted_redirect" + PaymentSourceWechatInAppResume = "wechat_in_app_resume" + + paymentResumeFallbackSigningKey = "sub2api-payment-resume" + + SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source" + SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source" + SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled" + SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled" + + VisibleMethodSourceOfficialAlipay = "official_alipay" + VisibleMethodSourceEasyPayAlipay = "easypay_alipay" + VisibleMethodSourceOfficialWechat = "official_wxpay" + VisibleMethodSourceEasyPayWechat = "easypay_wxpay" +) + +type ResumeTokenClaims struct { + OrderID int64 `json:"oid"` + UserID int64 `json:"uid,omitempty"` + ProviderInstanceID string `json:"pi,omitempty"` + ProviderKey string `json:"pk,omitempty"` + PaymentType string `json:"pt,omitempty"` + CanonicalReturnURL string `json:"ru,omitempty"` + IssuedAt int64 `json:"iat"` +} + +type PaymentResumeService struct { + signingKey []byte +} + +type visibleMethodLoadBalancer struct { + inner payment.LoadBalancer + configService *PaymentConfigService +} + +func NewPaymentResumeService(signingKey []byte) *PaymentResumeService { + return &PaymentResumeService{signingKey: signingKey} +} + +func NormalizeVisibleMethod(method string) string { + return payment.GetBasePaymentType(strings.TrimSpace(method)) +} + +func NormalizeVisibleMethods(methods []string) []string { + if len(methods) == 0 { + return nil + } + seen := make(map[string]struct{}, len(methods)) + out := make([]string, 0, len(methods)) + for _, method := range methods { + normalized := NormalizeVisibleMethod(method) + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + return out +} + +func NormalizePaymentSource(source string) string { + switch strings.TrimSpace(strings.ToLower(source)) { + case "", PaymentSourceHostedRedirect: + return PaymentSourceHostedRedirect + case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume: + return PaymentSourceWechatInAppResume + default: + return strings.TrimSpace(strings.ToLower(source)) + } +} + +func NormalizeVisibleMethodSource(method, source string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + switch strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official": + return VisibleMethodSourceOfficialAlipay + case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayAlipay + } + case payment.TypeWxpay: + switch strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official": + return VisibleMethodSourceOfficialWechat + case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayWechat + } + } + return "" +} + +func VisibleMethodProviderKeyForSource(method, source string) (string, bool) { + switch NormalizeVisibleMethodSource(method, source) { + case VisibleMethodSourceOfficialAlipay: + return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceEasyPayAlipay: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceOfficialWechat: + return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay + case VisibleMethodSourceEasyPayWechat: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay + default: + return "", false + } +} + +func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer { + if inner == nil || configService == nil || configService.settingRepo == nil { + return inner + } + return &visibleMethodLoadBalancer{inner: inner, configService: configService} +} + +func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) { + return lb.inner.GetInstanceConfig(ctx, instanceID) +} + +func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) { + visibleMethod := NormalizeVisibleMethod(paymentType) + if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) { + return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount) + } + + enabledKey := visibleMethodEnabledSettingKey(visibleMethod) + sourceKey := visibleMethodSourceSettingKey(visibleMethod) + vals, err := lb.configService.settingRepo.GetMultiple(ctx, []string{enabledKey, sourceKey}) + if err != nil { + return nil, fmt.Errorf("load visible method routing for %s: %w", visibleMethod, err) + } + if vals[enabledKey] != "true" { + return nil, fmt.Errorf("visible payment method %s is disabled", visibleMethod) + } + + targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[sourceKey]) + if !ok { + return nil, fmt.Errorf("visible payment method %s has no valid source", visibleMethod) + } + return lb.inner.SelectInstance(ctx, targetProviderKey, paymentType, strategy, orderAmount) +} + +func visibleMethodEnabledSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipayEnabled + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpayEnabled + default: + return "" + } +} + +func visibleMethodSourceSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipaySource + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpaySource + default: + return "" + } +} + +func CanonicalizeReturnURL(raw string) (string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", nil + } + parsed, err := url.Parse(raw) + if err != nil || !parsed.IsAbs() || parsed.Host == "" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https") + } + parsed.Fragment = "" + if parsed.Path == "" { + parsed.Path = "/" + } + return parsed.String(), nil +} + +func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) { + if claims.OrderID <= 0 { + return "", fmt.Errorf("resume token requires order id") + } + if claims.IssuedAt == 0 { + claims.IssuedAt = time.Now().Unix() + } + payload, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("marshal resume claims: %w", err) + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + return encodedPayload + "." + s.sign(encodedPayload), nil +} + +func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) { + parts := strings.Split(token, ".") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed") + } + if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed") + } + var claims ResumeTokenClaims + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid") + } + if claims.OrderID <= 0 { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id") + } + return &claims, nil +} + +func (s *PaymentResumeService) sign(payload string) string { + key := s.signingKey + if len(key) == 0 { + key = []byte(paymentResumeFallbackSigningKey) + } + mac := hmac.New(sha256.New, key) + _, _ = mac.Write([]byte(payload)) + return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go new file mode 100644 index 00000000..e56b4a88 --- /dev/null +++ b/backend/internal/service/payment_resume_service_test.go @@ -0,0 +1,240 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +func TestNormalizeVisibleMethods(t *testing.T) { + t.Parallel() + + got := NormalizeVisibleMethods([]string{ + "alipay_direct", + "alipay", + " wxpay_direct ", + "wxpay", + "stripe", + }) + + want := []string{"alipay", "wxpay", "stripe"} + if len(got) != len(want) { + t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestNormalizePaymentSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expect string + }{ + {name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect}, + {name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume}, + {name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizePaymentSource(tt.input); got != tt.expect { + t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +func TestCanonicalizeReturnURL(t *testing.T) { + t.Parallel() + + got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a") + if err != nil { + t.Fatalf("CanonicalizeReturnURL returned error: %v", err) + } + if got != "https://example.com/pay/result?b=2" { + t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2") + } +} + +func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("/payment/result"); err == nil { + t.Fatal("CanonicalizeReturnURL should reject relative URLs") + } +} + +func TestPaymentResumeTokenRoundTrip(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateToken(ResumeTokenClaims{ + OrderID: 42, + UserID: 7, + ProviderInstanceID: "19", + ProviderKey: "easypay", + PaymentType: "wxpay", + CanonicalReturnURL: "https://example.com/payment/result", + IssuedAt: 1234567890, + }) + if err != nil { + t.Fatalf("CreateToken returned error: %v", err) + } + + claims, err := svc.ParseToken(token) + if err != nil { + t.Fatalf("ParseToken returned error: %v", err) + } + if claims.OrderID != 42 || claims.UserID != 7 { + t.Fatalf("claims mismatch: %+v", claims) + } + if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" { + t.Fatalf("claims provider snapshot mismatch: %+v", claims) + } + if claims.CanonicalReturnURL != "https://example.com/payment/result" { + t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL) + } +} + +func TestNormalizeVisibleMethodSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + input string + want string + }{ + {name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay}, + {name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay}, + {name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat}, + {name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat}, + {name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want { + t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want) + } + }) + } +} + +func TestVisibleMethodProviderKeyForSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + source string + want string + ok bool + }{ + {name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true}, + {name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true}, + {name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true}, + {name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true}, + {name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source) + if got != tt.want || ok != tt.ok { + t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok) + } + }) + } +} + +func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) { + t.Parallel() + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + settingRepo: &paymentSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5) + if err != nil { + t.Fatalf("SelectInstance returned error: %v", err) + } + if inner.lastProviderKey != payment.TypeAlipay { + t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay) + } +} + +func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) { + t.Parallel() + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + settingRepo: &paymentSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodWxpayEnabled: "false", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil { + t.Fatal("SelectInstance should reject disabled visible method") + } +} + +type paymentSettingRepoStub struct { + values map[string]string +} + +func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, nil } +func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + return s.values[key], nil +} +func (s *paymentSettingRepoStub) Set(context.Context, string, string) error { return nil } +func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = s.values[key] + } + return out, nil +} +func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil } +func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return s.values, nil +} +func (s *paymentSettingRepoStub) Delete(context.Context, string) error { return nil } + +type captureLoadBalancer struct { + lastProviderKey string + lastPaymentType string +} + +func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) { + return map[string]string{}, nil +} + +func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) { + c.lastProviderKey = providerKey + c.lastPaymentType = paymentType + return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil +} diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 6fc23f97..e897741a 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -65,15 +65,17 @@ func generateRandomString(n int) string { } type CreateOrderRequest struct { - UserID int64 - Amount float64 - PaymentType string - ClientIP string - IsMobile bool - SrcHost string - SrcURL string - OrderType string - PlanID int64 + UserID int64 + Amount float64 + PaymentType string + ClientIP string + IsMobile bool + SrcHost string + SrcURL string + ReturnURL string + PaymentSource string + OrderType string + PlanID int64 } type CreateOrderResponse struct { @@ -88,6 +90,7 @@ type CreateOrderResponse struct { ClientSecret string `json:"client_secret,omitempty"` ExpiresAt time.Time `json:"expires_at"` PaymentMode string `json:"payment_mode,omitempty"` + ResumeToken string `json:"resume_token,omitempty"` } type OrderListParams struct { @@ -165,10 +168,13 @@ type PaymentService struct { configService *PaymentConfigService userRepo UserRepository groupRepo GroupRepository + resumeService *PaymentResumeService } func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { - return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService)) + return svc } // --- Provider Registry --- @@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string { return &s } +func (s *PaymentService) paymentResume() *PaymentResumeService { + if s.resumeService != nil { + return s.resumeService + } + return NewPaymentResumeService(psResumeSigningKey(s.configService)) +} + +func psResumeSigningKey(configService *PaymentConfigService) []byte { + if configService == nil { + return nil + } + return configService.encryptionKey +} + func psSliceContains(sl []string, s string) bool { for _, v := range sl { if v == s { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 7f4a2eb1..de555478 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "net/url" + "os" "sort" "strconv" "strings" @@ -114,6 +115,66 @@ type SettingService struct { webSearchManagerBuilder WebSearchManagerBuilder } +type ProviderDefaultGrantSettings struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting + GrantOnSignup bool + GrantOnFirstBind bool +} + +type AuthSourceDefaultSettings struct { + Email ProviderDefaultGrantSettings + LinuxDo ProviderDefaultGrantSettings + OIDC ProviderDefaultGrantSettings + WeChat ProviderDefaultGrantSettings + ForceEmailOnThirdPartySignup bool +} + +type authSourceDefaultKeySet struct { + balance string + concurrency string + subscriptions string + grantOnSignup string + grantOnFirstBind string +} + +var ( + emailAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultEmailBalance, + concurrency: SettingKeyAuthSourceDefaultEmailConcurrency, + subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + } + linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultLinuxDoBalance, + concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency, + subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + } + oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultOIDCBalance, + concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency, + subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + } + weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultWeChatBalance, + concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency, + subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + } +) + +const ( + defaultAuthSourceBalance = 0 + defaultAuthSourceConcurrency = 5 +) + // NewSettingService 创建系统设置服务实例 func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService { return &SettingService{ @@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings if oidcProviderName == "" { oidcProviderName = "OIDC" } + weChatEnabled := isWeChatOAuthConfigured() // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" @@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, + WeChatOAuthEnabled: weChatEnabled, BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", PaymentEnabled: settings[SettingPaymentEnabled] == "true", OIDCOAuthEnabled: oidcEnabled, @@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` PaymentEnabled bool `json:"payment_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` @@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, BackendModeEnabled: settings.BackendModeEnabled, PaymentEnabled: settings.PaymentEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, @@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage { return result } +func isWeChatOAuthConfigured() bool { + openConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" && + strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != "" + mpConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" && + strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != "" + return openConfigured || mpConfigured +} + // safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]". func safeRawJSONArray(raw string) json.RawMessage { raw = strings.TrimSpace(raw) @@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS return parseDefaultSubscriptions(value) } +func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) { + keys := []string{ + SettingKeyAuthSourceDefaultEmailBalance, + SettingKeyAuthSourceDefaultEmailConcurrency, + SettingKeyAuthSourceDefaultEmailSubscriptions, + SettingKeyAuthSourceDefaultEmailGrantOnSignup, + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + SettingKeyAuthSourceDefaultLinuxDoBalance, + SettingKeyAuthSourceDefaultLinuxDoConcurrency, + SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + SettingKeyAuthSourceDefaultOIDCBalance, + SettingKeyAuthSourceDefaultOIDCConcurrency, + SettingKeyAuthSourceDefaultOIDCSubscriptions, + SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + SettingKeyAuthSourceDefaultWeChatBalance, + SettingKeyAuthSourceDefaultWeChatConcurrency, + SettingKeyAuthSourceDefaultWeChatSubscriptions, + SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + SettingKeyForceEmailOnThirdPartySignup, + } + + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return nil, fmt.Errorf("get auth source default settings: %w", err) + } + + return &AuthSourceDefaultSettings{ + Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys), + LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys), + OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys), + WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys), + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", + }, nil +} + +func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error { + if settings == nil { + return nil + } + + for _, subscriptions := range [][]DefaultSubscriptionSetting{ + settings.Email.Subscriptions, + settings.LinuxDo.Subscriptions, + settings.OIDC.Subscriptions, + settings.WeChat.Subscriptions, + } { + if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { + return err + } + } + + updates := make(map[string]string, 21) + writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) + writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) + writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC) + writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) + updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) + + if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { + return fmt.Errorf("update auth source default settings: %w", err) + } + return nil +} + // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 初始化默认设置 defaults := map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyEmailVerifyEnabled: "false", - SettingKeyRegistrationEmailSuffixWhitelist: "[]", - SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 - SettingKeySiteName: "Sub2API", - SettingKeySiteLogo: "", - SettingKeyPurchaseSubscriptionEnabled: "false", - SettingKeyPurchaseSubscriptionURL: "", - SettingKeyTableDefaultPageSize: "20", - SettingKeyTablePageSizeOptions: "[10,20,50,100]", - SettingKeyCustomMenuItems: "[]", - SettingKeyCustomEndpoints: "[]", - SettingKeyOIDCConnectEnabled: "false", - SettingKeyOIDCConnectProviderName: "OIDC", - SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeyDefaultSubscriptions: "[]", - SettingKeySMTPPort: "587", - SettingKeySMTPUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeyRegistrationEmailSuffixWhitelist: "[]", + SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeySiteName: "Sub2API", + SettingKeySiteLogo: "", + SettingKeyPurchaseSubscriptionEnabled: "false", + SettingKeyPurchaseSubscriptionURL: "", + SettingKeyTableDefaultPageSize: "20", + SettingKeyTablePageSizeOptions: "[10,20,50,100]", + SettingKeyCustomMenuItems: "[]", + SettingKeyCustomEndpoints: "[]", + SettingKeyOIDCConnectEnabled: "false", + SettingKeyOIDCConnectProviderName: "OIDC", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailBalance: "0", + SettingKeyAuthSourceDefaultEmailConcurrency: "5", + SettingKeyAuthSourceDefaultEmailSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultLinuxDoBalance: "0", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]", + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultOIDCBalance: "0", + SettingKeyAuthSourceDefaultOIDCConcurrency: "5", + SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]", + SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "true", + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultWeChatBalance: "0", + SettingKeyAuthSourceDefaultWeChatConcurrency: "5", + SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]", + SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "true", + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false", + SettingKeyForceEmailOnThirdPartySignup: "false", + SettingKeySMTPPort: "587", + SettingKeySMTPUseTLS: "false", // Model fallback defaults SettingKeyEnableModelFallback: "false", SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", @@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken } + result.OIDCConnectUsePKCE = true + result.OIDCConnectValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) } else { @@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { return normalized } +func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: defaultAuthSourceBalance, + Concurrency: defaultAuthSourceConcurrency, + Subscriptions: []DefaultSubscriptionSetting{}, + GrantOnSignup: true, + GrantOnFirstBind: false, + } + + if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil { + result.Balance = v + } + if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil { + result.Concurrency = v + } + if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil { + result.Subscriptions = items + } + if raw, ok := settings[keys.grantOnSignup]; ok { + result.GrantOnSignup = raw == "true" + } + if raw, ok := settings[keys.grantOnFirstBind]; ok { + result.GrantOnFirstBind = raw == "true" + } + + return result +} + +func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) { + updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64) + updates[keys.concurrency] = strconv.Itoa(settings.Concurrency) + + subscriptions := settings.Subscriptions + if subscriptions == nil { + subscriptions = []DefaultSubscriptionSetting{} + } + raw, err := json.Marshal(subscriptions) + if err != nil { + raw = []byte("[]") + } + updates[keys.subscriptions] = string(raw) + updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup) + updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) +} + func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { defaultPageSize := 20 if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil { @@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { effective.RedirectURL = strings.TrimSpace(v) } + effective.UsePKCE = true if !effective.Enabled { return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") @@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } @@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { effective.ValidateIDToken = raw == "true" } + effective.UsePKCE = true + effective.ValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { effective.AllowedSigningAlgs = strings.TrimSpace(v) } @@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } diff --git a/backend/internal/service/setting_service_auth_source_defaults_test.go b/backend/internal/service/setting_service_auth_source_defaults_test.go new file mode 100644 index 00000000..097bf604 --- /dev/null +++ b/backend/internal/service/setting_service_auth_source_defaults_test.go @@ -0,0 +1,136 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type authSourceDefaultsRepoStub struct { + values map[string]string + updates map[string]string +} + +func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for key, value := range settings { + s.updates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) { + repo := &authSourceDefaultsRepoStub{ + values: map[string]string{ + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true", + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + got, err := svc.GetAuthSourceDefaultSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 12.5, got.Email.Balance) + require.Equal(t, 7, got.Email.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions) + require.False(t, got.Email.GrantOnSignup) + require.False(t, got.Email.GrantOnFirstBind) + require.Equal(t, 0.0, got.LinuxDo.Balance) + require.Equal(t, 5, got.LinuxDo.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions) + require.True(t, got.LinuxDo.GrantOnSignup) + require.True(t, got.LinuxDo.GrantOnFirstBind) + require.Equal(t, 5, got.OIDC.Concurrency) + require.Equal(t, 5, got.WeChat.Concurrency) + require.True(t, got.ForceEmailOnThirdPartySignup) +} + +func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) { + repo := &authSourceDefaultsRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{ + Balance: 1.25, + Concurrency: 3, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}}, + GrantOnSignup: false, + GrantOnFirstBind: true, + }, + LinuxDo: ProviderDefaultGrantSettings{ + Balance: 2, + Concurrency: 4, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}}, + GrantOnSignup: true, + GrantOnFirstBind: false, + }, + OIDC: ProviderDefaultGrantSettings{ + Balance: 3, + Concurrency: 5, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}}, + GrantOnSignup: true, + GrantOnFirstBind: true, + }, + WeChat: ProviderDefaultGrantSettings{ + Balance: 4, + Concurrency: 6, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, + GrantOnSignup: false, + GrantOnFirstBind: false, + }, + ForceEmailOnThirdPartySignup: true, + }) + require.NoError(t, err) + require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup]) + require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind]) + require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup]) + + var got []DefaultSubscriptionSetting + require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got)) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ab2eb274..e991ebef 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -152,6 +152,7 @@ type PublicSettings struct { CustomEndpoints string // JSON array of custom endpoints LinuxDoOAuthEnabled bool + WeChatOAuthEnabled bool BackendModeEnabled bool PaymentEnabled bool OIDCOAuthEnabled bool diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 59f8aa6b..d8b5325c 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -7,19 +7,27 @@ import ( ) type User struct { - ID int64 - Email string - Username string - Notes string - PasswordHash string - Role string - Balance float64 - Concurrency int - Status string - AllowedGroups []int64 - TokenVersion int64 // Incremented on password change to invalidate existing tokens - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + Email string + Username string + Notes string + AvatarURL string + AvatarSource string + AvatarMIME string + AvatarByteSize int + AvatarSHA256 string + PasswordHash string + Role string + Balance float64 + Concurrency int + Status string + AllowedGroups []int64 + TokenVersion int64 // Incremented on password change to invalidate existing tokens + SignupSource string + LastLoginAt *time.Time + LastActiveAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 3490e804..2f6d9427 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -2,9 +2,13 @@ package service import ( "context" + "crypto/sha256" "crypto/subtle" + "encoding/base64" + "encoding/hex" "fmt" "log/slog" + "net/url" "strings" "time" @@ -17,10 +21,14 @@ var ( ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") + ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL") + ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller") + ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image") ) const ( - maxNotifyEmails = 3 // Maximum number of notification emails per user + maxNotifyEmails = 3 // Maximum number of notification emails per user + maxInlineAvatarBytes = 100 * 1024 // User-level rate limiting for notify email verification codes notifyCodeUserRateLimit = 5 @@ -47,6 +55,9 @@ type UserRepository interface { GetFirstAdmin(ctx context.Context) (*User, error) Update(ctx context.Context, user *User) error Delete(ctx context.Context, id int64) error + GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) + UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + DeleteUserAvatar(ctx context.Context, userID int64) error List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) @@ -71,11 +82,30 @@ type UserRepository interface { type UpdateProfileRequest struct { Email *string `json:"email"` Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` Concurrency *int `json:"concurrency"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type UserAvatar struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + +type UpsertUserAvatarInput struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { CurrentPassword string `json:"current_password"` @@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro if err != nil { return nil, fmt.Errorf("get user: %w", err) } + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } @@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat user.Username = *req.Username } + if req.AvatarURL != nil { + avatarValue := strings.TrimSpace(*req.AvatarURL) + switch { + case avatarValue == "": + if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil { + return nil, fmt.Errorf("delete avatar: %w", err) + } + applyUserAvatar(user, nil) + default: + avatarInput, err := normalizeUserAvatarInput(avatarValue) + if err != nil { + return nil, err + } + avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput) + if err != nil { + return nil, fmt.Errorf("upsert avatar: %w", err) + } + applyUserAvatar(user, avatar) + } + } + if req.Concurrency != nil { user.Concurrency = *req.Concurrency } @@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat return user, nil } +func applyUserAvatar(user *User, avatar *UserAvatar) { + if user == nil { + return + } + if avatar == nil { + user.AvatarURL = "" + user.AvatarSource = "" + user.AvatarMIME = "" + user.AvatarByteSize = 0 + user.AvatarSHA256 = "" + return + } + + user.AvatarURL = avatar.URL + user.AvatarSource = avatar.StorageProvider + user.AvatarMIME = avatar.ContentType + user.AvatarByteSize = avatar.ByteSize + user.AvatarSHA256 = avatar.SHA256 +} + +func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.HasPrefix(raw, "data:") { + return normalizeInlineUserAvatarInput(raw) + } + + parsed, err := url.Parse(raw) + if err != nil || parsed == nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.TrimSpace(parsed.Host) == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + return UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: raw, + }, nil +} + +func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + body := strings.TrimPrefix(raw, "data:") + meta, encoded, ok := strings.Cut(body, ",") + if !ok { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + meta = strings.TrimSpace(meta) + encoded = strings.TrimSpace(encoded) + if !strings.HasSuffix(strings.ToLower(meta), ";base64") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")]) + if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") { + return UpsertUserAvatarInput{}, ErrAvatarNotImage + } + + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if len(decoded) > maxInlineAvatarBytes { + return UpsertUserAvatarInput{}, ErrAvatarTooLarge + } + + sum := sha256.Sum256(decoded) + return UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: raw, + ContentType: contentType, + ByteSize: len(decoded), + SHA256: hex.EncodeToString(sum[:]), + }, nil +} + // ChangePassword 修改密码 // Security: Increments TokenVersion to invalidate all existing JWT tokens func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error { @@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) { if err != nil { return nil, fmt.Errorf("get user: %w", err) } + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } +func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error { + if s == nil || s.userRepo == nil || user == nil || user.ID == 0 { + return nil + } + + avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID) + if err != nil { + return err + } + applyUserAvatar(user, avatar) + return nil +} + // List 获取用户列表(管理员功能) func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { users, pagination, err := s.userRepo.List(ctx, params) diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index a998d5f4..7d63bb36 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -4,6 +4,9 @@ package service import ( "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" "errors" "sync" "sync/atomic" @@ -19,14 +22,65 @@ import ( type mockUserRepo struct { updateBalanceErr error updateBalanceFn func(ctx context.Context, id int64, amount float64) error + getByIDUser *User + getByIDErr error + updateFn func(ctx context.Context, user *User) error + updateCalls int + upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + upsertAvatarArgs []UpsertUserAvatarInput + deleteAvatarFn func(ctx context.Context, userID int64) error + deleteAvatarIDs []int64 + getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error) } -func (m *mockUserRepo) Create(context.Context, *User) error { return nil } -func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) Create(context.Context, *User) error { return nil } +func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { + if m.getByIDErr != nil { + return nil, m.getByIDErr + } + if m.getByIDUser != nil { + cloned := *m.getByIDUser + return &cloned, nil + } + return &User{}, nil +} func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) Update(context.Context, *User) error { return nil } -func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) Update(ctx context.Context, user *User) error { + m.updateCalls++ + if m.updateFn != nil { + return m.updateFn(ctx, user) + } + return nil +} +func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + if m.getAvatarFn != nil { + return m.getAvatarFn(ctx, userID) + } + return nil, nil +} +func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + m.upsertAvatarArgs = append(m.upsertAvatarArgs, input) + if m.upsertAvatarFn != nil { + return m.upsertAvatarFn(ctx, userID, input) + } + return &UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID) + if m.deleteAvatarFn != nil { + return m.deleteAvatarFn(ctx, userID) + } + return nil +} func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { return nil, nil, nil } @@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) { require.Equal(t, auth, svc.authCacheInvalidator) require.Equal(t, cache, svc.billingCache) } + +func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) { + raw := []byte("small-avatar") + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + expectedSum := sha256.Sum256(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 7, + Email: "avatar@example.com", + Username: "avatar-user", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType) + require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256) + require.Equal(t, dataURL, updated.AvatarURL) + require.Equal(t, "inline", updated.AvatarSource) + require.Equal(t, "image/png", updated.AvatarMIME) + require.Equal(t, len(raw), updated.AvatarByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256) +} + +func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) { + raw := make([]byte, maxInlineAvatarBytes+1) + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 8, + Email: "large-avatar@example.com", + Username: "too-large", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + _, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.ErrorIs(t, err, ErrAvatarTooLarge) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, repo.deleteAvatarIDs) + require.Zero(t, repo.updateCalls) +} + +func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) { + remoteURL := "https://cdn.example.com/avatar.png" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 9, + Email: "remote-avatar@example.com", + Username: "remote-avatar", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{ + AvatarURL: &remoteURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL) + require.Equal(t, remoteURL, updated.AvatarURL) + require.Equal(t, "remote_url", updated.AvatarSource) + require.Zero(t, updated.AvatarByteSize) +} + +func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) { + empty := "" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 10, + Email: "delete-avatar@example.com", + Username: "delete-avatar", + AvatarURL: "https://cdn.example.com/old.png", + AvatarSource: "remote_url", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{ + AvatarURL: &empty, + }) + require.NoError(t, err) + require.Equal(t, []int64{10}, repo.deleteAvatarIDs) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, updated.AvatarURL) + require.Empty(t, updated.AvatarSource) +} + +func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 12, + Email: "profile-avatar@example.com", + Username: "profile-avatar", + }, + getAvatarFn: func(context.Context, int64) (*UserAvatar, error) { + return &UserAvatar{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/profile.png", + }, nil + }, + } + svc := NewUserService(repo, nil, nil, nil) + + user, err := svc.GetProfile(context.Background(), 12) + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL) + require.Equal(t, "remote_url", user.AvatarSource) +} diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql new file mode 100644 index 00000000..117e3ca3 --- /dev/null +++ b/backend/migrations/108_auth_identity_foundation_core.sql @@ -0,0 +1,141 @@ +ALTER TABLE users +ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email', +ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL, +ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL; + +UPDATE users +SET signup_source = 'email' +WHERE signup_source IS NULL OR signup_source = ''; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'users_signup_source_check' + ) THEN + ALTER TABLE users + ADD CONSTRAINT users_signup_source_check + CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc')); + END IF; +END $$; + +CREATE TABLE IF NOT EXISTS auth_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + verified_at TIMESTAMPTZ NULL, + issuer TEXT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identities_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key + ON auth_identities (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx + ON auth_identities (user_id); + +CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx + ON auth_identities (user_id, provider_type); + +CREATE TABLE IF NOT EXISTS auth_identity_channels ( + id BIGSERIAL PRIMARY KEY, + identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + channel VARCHAR(20) NOT NULL, + channel_app_id TEXT NOT NULL, + channel_subject TEXT NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identity_channels_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key + ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject); + +CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx + ON auth_identity_channels (identity_id); + +CREATE TABLE IF NOT EXISTS pending_auth_sessions ( + id BIGSERIAL PRIMARY KEY, + session_token VARCHAR(255) NOT NULL, + intent VARCHAR(40) NOT NULL, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL, + redirect_to TEXT NOT NULL DEFAULT '', + resolved_email TEXT NOT NULL DEFAULT '', + registration_password_hash TEXT NOT NULL DEFAULT '', + upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb, + local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb, + browser_session_key TEXT NOT NULL DEFAULT '', + completion_code_hash TEXT NOT NULL DEFAULT '', + completion_code_expires_at TIMESTAMPTZ NULL, + email_verified_at TIMESTAMPTZ NULL, + password_verified_at TIMESTAMPTZ NULL, + totp_verified_at TIMESTAMPTZ NULL, + expires_at TIMESTAMPTZ NOT NULL, + consumed_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT pending_auth_sessions_intent_check + CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')), + CONSTRAINT pending_auth_sessions_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key + ON pending_auth_sessions (session_token); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx + ON pending_auth_sessions (target_user_id); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx + ON pending_auth_sessions (expires_at); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx + ON pending_auth_sessions (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx + ON pending_auth_sessions (completion_code_hash); + +CREATE TABLE IF NOT EXISTS identity_adoption_decisions ( + id BIGSERIAL PRIMARY KEY, + pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE, + identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL, + adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE, + adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE, + decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key + ON identity_adoption_decisions (pending_auth_session_id); + +CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx + ON identity_adoption_decisions (identity_id); + +CREATE TABLE IF NOT EXISTS auth_identity_migration_reports ( + id BIGSERIAL PRIMARY KEY, + report_type VARCHAR(40) NOT NULL, + report_key TEXT NOT NULL, + details JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx + ON auth_identity_migration_reports (report_type); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key + ON auth_identity_migration_reports (report_type, report_key); diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql new file mode 100644 index 00000000..ddbbedbc --- /dev/null +++ b/backend/migrations/109_auth_identity_compat_backfill.sql @@ -0,0 +1,125 @@ +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'email', + 'email', + LOWER(BTRIM(u.email)), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'users.email', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND BTRIM(COALESCE(u.email, '')) <> '' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'linuxdo', + 'linuxdo', + SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'wechat', + 'wechat', + SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +UPDATE users +SET signup_source = 'linuxdo' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'; + +UPDATE users +SET signup_source = 'wechat' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$'; + +UPDATE users +SET signup_source = 'oidc' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$'; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'oidc_synthetic_email_requires_manual_recovery', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities ai + WHERE ai.user_id = u.id + AND ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + ) +ON CONFLICT (report_type, report_key) DO NOTHING; diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql new file mode 100644 index 00000000..fbaed62e --- /dev/null +++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql @@ -0,0 +1,60 @@ +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind', + granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT user_provider_default_grants_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')), + CONSTRAINT user_provider_default_grants_reason_check + CHECK (grant_reason IN ('signup', 'first_bind')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key + ON user_provider_default_grants (user_id, provider_type, grant_reason); + +CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx + ON user_provider_default_grants (user_id); + +CREATE TABLE IF NOT EXISTS user_avatars ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + storage_provider VARCHAR(20) NOT NULL DEFAULT 'database', + storage_key TEXT NOT NULL DEFAULT '', + url TEXT NOT NULL DEFAULT '', + content_type VARCHAR(100) NOT NULL DEFAULT '', + byte_size INT NOT NULL DEFAULT 0, + sha256 VARCHAR(64) NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key + ON user_avatars (user_id); + +INSERT INTO settings (key, value) +VALUES + ('auth_source_default_email_balance', '0'), + ('auth_source_default_email_concurrency', '5'), + ('auth_source_default_email_subscriptions', '[]'), + ('auth_source_default_email_grant_on_signup', 'true'), + ('auth_source_default_email_grant_on_first_bind', 'false'), + ('auth_source_default_linuxdo_balance', '0'), + ('auth_source_default_linuxdo_concurrency', '5'), + ('auth_source_default_linuxdo_subscriptions', '[]'), + ('auth_source_default_linuxdo_grant_on_signup', 'true'), + ('auth_source_default_linuxdo_grant_on_first_bind', 'false'), + ('auth_source_default_oidc_balance', '0'), + ('auth_source_default_oidc_concurrency', '5'), + ('auth_source_default_oidc_subscriptions', '[]'), + ('auth_source_default_oidc_grant_on_signup', 'true'), + ('auth_source_default_oidc_grant_on_first_bind', 'false'), + ('auth_source_default_wechat_balance', '0'), + ('auth_source_default_wechat_concurrency', '5'), + ('auth_source_default_wechat_subscriptions', '[]'), + ('auth_source_default_wechat_grant_on_signup', 'true'), + ('auth_source_default_wechat_grant_on_first_bind', 'false'), + ('force_email_on_third_party_signup', 'false') +ON CONFLICT (key) DO NOTHING; + diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql new file mode 100644 index 00000000..f222a8d4 --- /dev/null +++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql @@ -0,0 +1,8 @@ +INSERT INTO settings (key, value) +VALUES + ('payment_visible_method_alipay_source', ''), + ('payment_visible_method_wxpay_source', ''), + ('payment_visible_method_alipay_enabled', 'false'), + ('payment_visible_method_wxpay_enabled', 'false'), + ('openai_advanced_scheduler_enabled', 'false') +ON CONFLICT (key) DO NOTHING; diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts new file mode 100644 index 00000000..574e1e36 --- /dev/null +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -0,0 +1,60 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const post = vi.fn() + +vi.mock('@/api/client', () => ({ + apiClient: { + post + } +})) + +describe('oauth adoption auth api', () => { + beforeEach(() => { + post.mockReset() + post.mockResolvedValue({ data: {} }) + }) + + it('posts adoption decisions when exchanging pending oauth completion', async () => { + const { exchangePendingOAuthCompletion } = await import('@/api/auth') + + await exchangePendingOAuthCompletion({ + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', { + adopt_display_name: false, + adopt_avatar: true + }) + }) + + it('posts linuxdo invitation completion with adoption decisions', async () => { + const { completeLinuxDoOAuthRegistration } = await import('@/api/auth') + + await completeLinuxDoOAuthRegistration('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: false + }) + }) + + it('posts oidc invitation completion with adoption decisions', async () => { + const { completeOIDCOAuthRegistration } = await import('@/api/auth') + + await completeOIDCOAuthRegistration('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: true + }) + }) +}) diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts new file mode 100644 index 00000000..8756146e --- /dev/null +++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts @@ -0,0 +1,118 @@ +import { describe, expect, it } from 'vitest' + +import { + appendAuthSourceDefaultsToUpdateRequest, + buildAuthSourceDefaultsState, + type UpdateSettingsRequest, +} from '@/api/admin/settings' + +describe('admin settings auth source defaults helpers', () => { + it('builds auth source defaults state from flat settings fields', () => { + const state = buildAuthSourceDefaultsState({ + auth_source_default_email_balance: 9.5, + auth_source_default_email_concurrency: 3, + auth_source_default_email_subscriptions: [ + { group_id: 1, validity_days: 30 }, + ], + auth_source_default_email_grant_on_signup: false, + auth_source_default_email_grant_on_first_bind: true, + auth_source_default_linuxdo_balance: 6, + auth_source_default_linuxdo_concurrency: 8, + auth_source_default_linuxdo_subscriptions: [ + { group_id: 2, validity_days: 60 }, + ], + auth_source_default_linuxdo_grant_on_signup: true, + auth_source_default_linuxdo_grant_on_first_bind: false, + }) + + expect(state.email).toEqual({ + balance: 9.5, + concurrency: 3, + subscriptions: [{ group_id: 1, validity_days: 30 }], + grant_on_signup: false, + grant_on_first_bind: true, + }) + expect(state.linuxdo).toEqual({ + balance: 6, + concurrency: 8, + subscriptions: [{ group_id: 2, validity_days: 60 }], + grant_on_signup: true, + grant_on_first_bind: false, + }) + expect(state.oidc).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: true, + grant_on_first_bind: false, + }) + expect(state.wechat).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: true, + grant_on_first_bind: false, + }) + }) + + it('appends auth source defaults back onto update payload', () => { + const payload: UpdateSettingsRequest = { + site_name: 'Sub2API', + } + + appendAuthSourceDefaultsToUpdateRequest(payload, { + email: { + balance: 1.25, + concurrency: 2, + subscriptions: [{ group_id: 3, validity_days: 7 }], + grant_on_signup: true, + grant_on_first_bind: false, + }, + linuxdo: { + balance: 0, + concurrency: 6, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: true, + }, + oidc: { + balance: 4, + concurrency: 9, + subscriptions: [{ group_id: 9, validity_days: 90 }], + grant_on_signup: true, + grant_on_first_bind: true, + }, + wechat: { + balance: 2, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + }, + }) + + expect(payload).toMatchObject({ + site_name: 'Sub2API', + auth_source_default_email_balance: 1.25, + auth_source_default_email_concurrency: 2, + auth_source_default_email_subscriptions: [{ group_id: 3, validity_days: 7 }], + auth_source_default_email_grant_on_signup: true, + auth_source_default_email_grant_on_first_bind: false, + auth_source_default_linuxdo_balance: 0, + auth_source_default_linuxdo_concurrency: 6, + auth_source_default_linuxdo_subscriptions: [], + auth_source_default_linuxdo_grant_on_signup: false, + auth_source_default_linuxdo_grant_on_first_bind: true, + auth_source_default_oidc_balance: 4, + auth_source_default_oidc_concurrency: 9, + auth_source_default_oidc_subscriptions: [{ group_id: 9, validity_days: 90 }], + auth_source_default_oidc_grant_on_signup: true, + auth_source_default_oidc_grant_on_first_bind: true, + auth_source_default_wechat_balance: 2, + auth_source_default_wechat_concurrency: 5, + auth_source_default_wechat_subscriptions: [], + auth_source_default_wechat_grant_on_signup: false, + auth_source_default_wechat_grant_on_first_bind: false, + }) + }) +}) diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 1e4a3053..8e182c1c 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -11,6 +11,81 @@ export interface DefaultSubscriptionSetting { validity_days: number } +export type AuthSourceType = 'email' | 'linuxdo' | 'oidc' | 'wechat' + +export interface AuthSourceDefaultsValue { + balance: number + concurrency: number + subscriptions: DefaultSubscriptionSetting[] + grant_on_signup: boolean + grant_on_first_bind: boolean +} + +export type AuthSourceDefaultsState = Record + +const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat'] +const AUTH_SOURCE_DEFAULT_BALANCE = 0 +const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5 + +export function normalizeDefaultSubscriptionSettings( + subscriptions: DefaultSubscriptionSetting[] | null | undefined +): DefaultSubscriptionSetting[] { + if (!Array.isArray(subscriptions)) return [] + + return subscriptions + .filter((item) => item.group_id > 0 && item.validity_days > 0) + .map((item) => ({ + group_id: Math.floor(item.group_id), + validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days))) + })) +} + +export function buildAuthSourceDefaultsState( + settings: Partial +): AuthSourceDefaultsState { + const raw = settings as Record + + return AUTH_SOURCE_TYPES.reduce((acc, source) => { + const subscriptions = raw[`auth_source_default_${source}_subscriptions`] + acc[source] = { + balance: Number(raw[`auth_source_default_${source}_balance`] ?? AUTH_SOURCE_DEFAULT_BALANCE), + concurrency: Math.max( + 1, + Number(raw[`auth_source_default_${source}_concurrency`] ?? AUTH_SOURCE_DEFAULT_CONCURRENCY) + ), + subscriptions: normalizeDefaultSubscriptionSettings( + Array.isArray(subscriptions) ? (subscriptions as DefaultSubscriptionSetting[]) : [] + ), + grant_on_signup: raw[`auth_source_default_${source}_grant_on_signup`] !== false, + grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true, + } + return acc + }, {} as AuthSourceDefaultsState) +} + +export function appendAuthSourceDefaultsToUpdateRequest( + payload: UpdateSettingsRequest, + authSourceDefaults: AuthSourceDefaultsState +): UpdateSettingsRequest { + const target = payload as Record + + for (const source of AUTH_SOURCE_TYPES) { + const current = authSourceDefaults[source] + target[`auth_source_default_${source}_balance`] = Number(current.balance) || 0 + target[`auth_source_default_${source}_concurrency`] = Math.max( + 1, + Math.floor(Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY) + ) + target[`auth_source_default_${source}_subscriptions`] = normalizeDefaultSubscriptionSettings( + current.subscriptions + ) + target[`auth_source_default_${source}_grant_on_signup`] = current.grant_on_signup + target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind + } + + return payload +} + /** * System settings interface */ @@ -29,6 +104,27 @@ export interface SystemSettings { default_balance: number default_concurrency: number default_subscriptions: DefaultSubscriptionSetting[] + auth_source_default_email_balance?: number + auth_source_default_email_concurrency?: number + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_email_grant_on_signup?: boolean + auth_source_default_email_grant_on_first_bind?: boolean + auth_source_default_linuxdo_balance?: number + auth_source_default_linuxdo_concurrency?: number + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_linuxdo_grant_on_signup?: boolean + auth_source_default_linuxdo_grant_on_first_bind?: boolean + auth_source_default_oidc_balance?: number + auth_source_default_oidc_concurrency?: number + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_oidc_grant_on_signup?: boolean + auth_source_default_oidc_grant_on_first_bind?: boolean + auth_source_default_wechat_balance?: number + auth_source_default_wechat_concurrency?: number + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_wechat_grant_on_signup?: boolean + auth_source_default_wechat_grant_on_first_bind?: boolean + force_email_on_third_party_signup?: boolean // OEM settings site_name: string site_logo: string @@ -137,6 +233,11 @@ export interface SystemSettings { payment_cancel_rate_limit_window: number payment_cancel_rate_limit_unit: string payment_cancel_rate_limit_window_mode: string + payment_visible_method_alipay_source?: string + payment_visible_method_wxpay_source?: string + payment_visible_method_alipay_enabled?: boolean + payment_visible_method_wxpay_enabled?: boolean + openai_advanced_scheduler_enabled?: boolean // Balance & quota notification balance_low_notify_enabled: boolean @@ -158,6 +259,27 @@ export interface UpdateSettingsRequest { default_balance?: number default_concurrency?: number default_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_email_balance?: number + auth_source_default_email_concurrency?: number + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_email_grant_on_signup?: boolean + auth_source_default_email_grant_on_first_bind?: boolean + auth_source_default_linuxdo_balance?: number + auth_source_default_linuxdo_concurrency?: number + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_linuxdo_grant_on_signup?: boolean + auth_source_default_linuxdo_grant_on_first_bind?: boolean + auth_source_default_oidc_balance?: number + auth_source_default_oidc_concurrency?: number + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_oidc_grant_on_signup?: boolean + auth_source_default_oidc_grant_on_first_bind?: boolean + auth_source_default_wechat_balance?: number + auth_source_default_wechat_concurrency?: number + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_wechat_grant_on_signup?: boolean + auth_source_default_wechat_grant_on_first_bind?: boolean + force_email_on_third_party_signup?: boolean site_name?: string site_logo?: string site_subtitle?: string @@ -245,6 +367,11 @@ export interface UpdateSettingsRequest { payment_cancel_rate_limit_window?: number payment_cancel_rate_limit_unit?: string payment_cancel_rate_limit_window_mode?: string + payment_visible_method_alipay_source?: string + payment_visible_method_wxpay_source?: string + payment_visible_method_alipay_enabled?: boolean + payment_visible_method_wxpay_enabled?: boolean + openai_advanced_scheduler_enabled?: boolean // Balance & quota notification balance_low_notify_enabled?: boolean balance_low_notify_threshold?: number diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index d7abcd6a..10b6ca58 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -198,6 +198,26 @@ export interface PendingOAuthExchangeResponse { suggested_avatar_url?: string } +export interface OAuthAdoptionDecision { + adoptDisplayName?: boolean + adoptAvatar?: boolean +} + +function serializeOAuthAdoptionDecision( + decision?: OAuthAdoptionDecision +): Record { + const payload: Record = {} + + if (typeof decision?.adoptDisplayName === 'boolean') { + payload.adopt_display_name = decision.adoptDisplayName + } + if (typeof decision?.adoptAvatar === 'boolean') { + payload.adopt_avatar = decision.adoptAvatar + } + + return payload +} + /** * Refresh the access token using the refresh token * @returns New token pair @@ -353,7 +373,8 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { const { data } = await apiClient.post<{ access_token: string @@ -361,7 +382,8 @@ export async function completeLinuxDoOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/linuxdo/complete-registration', { - invitation_code: invitationCode + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) }) return data } @@ -372,7 +394,8 @@ export async function completeLinuxDoOAuthRegistration( * @returns Token pair on success */ export async function completeOIDCOAuthRegistration( - invitationCode: string + invitationCode: string, + decision?: OAuthAdoptionDecision ): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { const { data } = await apiClient.post<{ access_token: string @@ -380,13 +403,19 @@ export async function completeOIDCOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/oidc/complete-registration', { - invitation_code: invitationCode + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) }) return data } -export async function exchangePendingOAuthCompletion(): Promise { - const { data } = await apiClient.post('/auth/oauth/pending/exchange', {}) +export async function exchangePendingOAuthCompletion( + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + '/auth/oauth/pending/exchange', + serializeOAuthAdoptionDecision(decision) + ) return data } diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue new file mode 100644 index 00000000..94e20222 --- /dev/null +++ b/frontend/src/components/auth/WechatOAuthSection.vue @@ -0,0 +1,53 @@ + + + diff --git a/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts new file mode 100644 index 00000000..810832a0 --- /dev/null +++ b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts @@ -0,0 +1,74 @@ +import { mount } from '@vue/test-utils' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import WechatOAuthSection from '@/components/auth/WechatOAuthSection.vue' + +const routeState = vi.hoisted(() => ({ + query: {} as Record, +})) + +const locationState = vi.hoisted(() => ({ + current: { href: 'http://localhost/login' } as { href: string }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oidc.signIn') { + return `Continue with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oauthOrContinue') { + return 'or continue' + } + return key + }, + }), +})) + +describe('WechatOAuthSection', () => { + beforeEach(() => { + routeState.query = { redirect: '/billing?plan=pro' } + locationState.current = { href: 'http://localhost/login' } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0', + }) + }) + + afterEach(() => { + vi.unstubAllGlobals() + }) + + it('starts the open WeChat OAuth flow with the current redirect target', async () => { + const wrapper = mount(WechatOAuthSection) + + expect(wrapper.text()).toContain('WeChat') + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toContain( + '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro' + ) + }) + + it('uses mp mode inside the WeChat browser', async () => { + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0 MicroMessenger', + }) + const wrapper = mount(WechatOAuthSection) + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toContain( + '/api/v1/auth/oauth/wechat/start?mode=mp&redirect=%2Fbilling%3Fplan%3Dpro' + ) + }) +}) diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue index 974dee66..8f5a5666 100644 --- a/frontend/src/components/payment/PaymentStatusPanel.vue +++ b/frontend/src/components/payment/PaymentStatusPanel.vue @@ -141,7 +141,9 @@ const props = defineProps<{ orderType?: string }>() -const emit = defineEmits<{ done: []; success: [] }>() +type PaymentOutcome = 'success' | 'cancelled' | 'expired' + +const emit = defineEmits<{ done: []; success: []; settled: [outcome: PaymentOutcome] }>() const { t } = useI18n() const paymentStore = usePaymentStore() @@ -154,7 +156,7 @@ const cancelling = ref(false) const paidOrder = ref(null) // Terminal outcome: null = still active, 'success' | 'cancelled' | 'expired' -const outcome = ref<'success' | 'cancelled' | 'expired' | null>(null) +const outcome = ref(null) let pollTimer: ReturnType | null = null let countdownTimer: ReturnType | null = null @@ -194,10 +196,19 @@ const countdownDisplay = computed(() => { function reopenPopup() { if (props.payUrl) { - window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + const win = window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + if (!win || win.closed) { + window.location.href = props.payUrl + } } } +function setOutcome(next: PaymentOutcome) { + if (outcome.value === next) return + outcome.value = next + emit('settled', next) +} + async function renderQR() { await nextTick() if (!qrCanvas.value || !qrUrl.value) return @@ -214,23 +225,23 @@ async function pollStatus() { if (order.status === 'COMPLETED' || order.status === 'PAID') { cleanup() paidOrder.value = order - outcome.value = 'success' + setOutcome('success') emit('success') } else if (order.status === 'CANCELLED') { cleanup() - outcome.value = 'cancelled' + setOutcome('cancelled') } else if (order.status === 'EXPIRED' || order.status === 'FAILED') { cleanup() - outcome.value = 'expired' + setOutcome('expired') } } function startCountdown(seconds: number) { remainingSeconds.value = Math.max(0, seconds) - if (remainingSeconds.value <= 0) { outcome.value = 'expired'; return } + if (remainingSeconds.value <= 0) { setOutcome('expired'); return } countdownTimer = setInterval(() => { remainingSeconds.value-- - if (remainingSeconds.value <= 0) { outcome.value = 'expired'; cleanup() } + if (remainingSeconds.value <= 0) { setOutcome('expired'); cleanup() } }, 1000) } @@ -240,7 +251,7 @@ async function handleCancel() { try { await paymentAPI.cancelOrder(props.orderId) cleanup() - outcome.value = 'cancelled' + setOutcome('cancelled') } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } finally { diff --git a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts new file mode 100644 index 00000000..f5212f15 --- /dev/null +++ b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts @@ -0,0 +1,163 @@ +import { describe, expect, it } from 'vitest' +import type { CreateOrderResult, MethodLimit } from '@/types/payment' +import { + decidePaymentLaunch, + getVisibleMethods, + readPaymentRecoverySnapshot, + type PaymentRecoverySnapshot, +} from '@/components/payment/paymentFlow' + +function methodLimit(overrides: Partial = {}): MethodLimit { + return { + daily_limit: 0, + daily_used: 0, + daily_remaining: 0, + single_min: 0, + single_max: 0, + fee_rate: 0, + available: true, + ...overrides, + } +} + +function createOrderResult(overrides: Partial = {}): CreateOrderResult { + return { + order_id: 101, + amount: 88, + pay_amount: 88, + fee_rate: 0, + expires_at: '2099-01-01T00:10:00.000Z', + ...overrides, + } +} + +describe('getVisibleMethods', () => { + it('filters hidden provider methods and normalizes aliases', () => { + const visible = getVisibleMethods({ + alipay_direct: methodLimit({ single_min: 5 }), + wxpay: methodLimit({ single_max: 100 }), + stripe: methodLimit({ fee_rate: 3 }), + }) + + expect(visible).toEqual({ + alipay: methodLimit({ single_min: 5 }), + wxpay: methodLimit({ single_max: 100 }), + }) + }) + + it('prefers canonical visible methods over aliases when both exist', () => { + const visible = getVisibleMethods({ + alipay: methodLimit({ single_min: 2 }), + alipay_direct: methodLimit({ single_min: 9 }), + wxpay_direct: methodLimit({ fee_rate: 1.2 }), + }) + + expect(visible.alipay.single_min).toBe(2) + expect(visible.wxpay.fee_rate).toBe(1.2) + }) +}) + +describe('decidePaymentLaunch', () => { + it('uses Stripe popup waiting flow for desktop Alipay client secret', () => { + const decision = decidePaymentLaunch(createOrderResult({ + client_secret: 'cs_test', + resume_token: 'resume-1', + }), { + visibleMethod: 'alipay', + orderType: 'balance', + isMobile: false, + }) + + expect(decision.kind).toBe('stripe_popup') + expect(decision.paymentState.paymentType).toBe('alipay') + expect(decision.stripeMethod).toBe('alipay') + expect(decision.recovery.resumeToken).toBe('resume-1') + }) + + it('uses Stripe route flow for mobile WeChat client secret', () => { + const decision = decidePaymentLaunch(createOrderResult({ + client_secret: 'cs_test', + }), { + visibleMethod: 'wxpay', + orderType: 'subscription', + isMobile: true, + }) + + expect(decision.kind).toBe('stripe_route') + expect(decision.stripeMethod).toBe('wechat_pay') + expect(decision.paymentState.orderType).toBe('subscription') + }) + + it('keeps hosted redirect metadata for recovery flows', () => { + const decision = decidePaymentLaunch(createOrderResult({ + pay_url: 'https://pay.example.com/session/abc', + payment_mode: 'popup', + resume_token: 'resume-2', + }), { + visibleMethod: 'wxpay', + orderType: 'balance', + isMobile: false, + }) + + expect(decision.kind).toBe('redirect_waiting') + expect(decision.paymentState.payUrl).toBe('https://pay.example.com/session/abc') + expect(decision.recovery.paymentMode).toBe('popup') + expect(decision.recovery.resumeToken).toBe('resume-2') + }) +}) + +describe('readPaymentRecoverySnapshot', () => { + it('restores an unexpired snapshot when the resume token matches', () => { + const snapshot: PaymentRecoverySnapshot = { + orderId: 33, + amount: 18, + qrCode: '', + expiresAt: '2099-01-01T00:10:00.000Z', + paymentType: 'alipay', + payUrl: 'https://pay.example.com/session/33', + clientSecret: '', + payAmount: 18, + orderType: 'balance', + paymentMode: 'popup', + resumeToken: 'resume-33', + createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), + } + + const restored = readPaymentRecoverySnapshot(JSON.stringify(snapshot), { + now: Date.UTC(2099, 0, 1, 0, 1, 0), + resumeToken: 'resume-33', + }) + + expect(restored?.orderId).toBe(33) + }) + + it('drops expired or mismatched recovery snapshots', () => { + const expiredSnapshot: PaymentRecoverySnapshot = { + orderId: 55, + amount: 18, + qrCode: '', + expiresAt: '2024-01-01T00:10:00.000Z', + paymentType: 'wxpay', + payUrl: 'https://pay.example.com/session/55', + clientSecret: '', + payAmount: 18, + orderType: 'balance', + paymentMode: 'popup', + resumeToken: 'resume-55', + createdAt: Date.UTC(2024, 0, 1, 0, 0, 0), + } + + expect(readPaymentRecoverySnapshot(JSON.stringify(expiredSnapshot), { + now: Date.UTC(2024, 0, 1, 0, 20, 0), + resumeToken: 'resume-55', + })).toBeNull() + + expect(readPaymentRecoverySnapshot(JSON.stringify({ + ...expiredSnapshot, + expiresAt: '2099-01-01T00:10:00.000Z', + }), { + now: Date.UTC(2099, 0, 1, 0, 1, 0), + resumeToken: 'other-token', + })).toBeNull() + }) +}) diff --git a/frontend/src/components/payment/paymentFlow.ts b/frontend/src/components/payment/paymentFlow.ts new file mode 100644 index 00000000..70225a0c --- /dev/null +++ b/frontend/src/components/payment/paymentFlow.ts @@ -0,0 +1,197 @@ +import type { CreateOrderResult, MethodLimit, OrderType } from '@/types/payment' + +export const PAYMENT_RECOVERY_STORAGE_KEY = 'payment.recovery.current' + +const VISIBLE_METHOD_ALIASES = { + alipay: 'alipay', + alipay_direct: 'alipay', + wxpay: 'wxpay', + wxpay_direct: 'wxpay', +} as const + +export type VisiblePaymentMethod = 'alipay' | 'wxpay' +export type StripeVisibleMethod = 'alipay' | 'wechat_pay' +export type PaymentLaunchKind = + | 'qr_waiting' + | 'redirect_waiting' + | 'stripe_popup' + | 'stripe_route' + | 'unhandled' + +export interface PaymentRecoverySnapshot { + orderId: number + amount: number + qrCode: string + expiresAt: string + paymentType: string + payUrl: string + clientSecret: string + payAmount: number + orderType: OrderType | '' + paymentMode: string + resumeToken: string + createdAt: number +} + +export interface PaymentLaunchContext { + visibleMethod: string + orderType: OrderType + isMobile: boolean + now?: number + stripePopupUrl?: string + stripeRouteUrl?: string +} + +export interface PaymentLaunchDecision { + kind: PaymentLaunchKind + paymentState: PaymentRecoverySnapshot + recovery: PaymentRecoverySnapshot + stripeMethod?: StripeVisibleMethod +} + +type CreateOrderFlowResult = CreateOrderResult & { + resume_token?: string +} + +type StorageWriter = Pick + +export function normalizeVisibleMethod(method: string): VisiblePaymentMethod | '' { + const normalized = VISIBLE_METHOD_ALIASES[method.trim() as keyof typeof VISIBLE_METHOD_ALIASES] + return normalized ?? '' +} + +export function getVisibleMethods(methods: Record): Record { + const visible: Record = {} + + Object.entries(methods).forEach(([type, limit]) => { + const normalized = normalizeVisibleMethod(type) + if (!normalized) return + + const isCanonical = type === normalized + const existing = visible[normalized] + if (!existing || isCanonical) { + visible[normalized] = { ...limit } + } + }) + + return visible +} + +export function decidePaymentLaunch( + result: CreateOrderFlowResult, + context: PaymentLaunchContext, +): PaymentLaunchDecision { + const visibleMethod = normalizeVisibleMethod(context.visibleMethod) || context.visibleMethod + const baseState = createPaymentRecoverySnapshot({ + orderId: result.order_id, + amount: result.amount, + qrCode: result.qr_code || '', + expiresAt: result.expires_at || '', + paymentType: visibleMethod, + payUrl: result.pay_url || '', + clientSecret: result.client_secret || '', + payAmount: result.pay_amount, + orderType: context.orderType, + paymentMode: (result.payment_mode || '').trim(), + resumeToken: result.resume_token || '', + }, context.now) + + if (baseState.clientSecret) { + const stripeMethod: StripeVisibleMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay' + const kind: PaymentLaunchKind = stripeMethod === 'alipay' && !context.isMobile + ? 'stripe_popup' + : 'stripe_route' + const payUrl = kind === 'stripe_popup' + ? context.stripePopupUrl || context.stripeRouteUrl || '' + : context.stripeRouteUrl || context.stripePopupUrl || '' + const paymentState = { ...baseState, payUrl } + return { kind, paymentState, recovery: paymentState, stripeMethod } + } + + if (baseState.qrCode) { + return { kind: 'qr_waiting', paymentState: baseState, recovery: baseState } + } + + if (baseState.payUrl) { + return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState } + } + + return { kind: 'unhandled', paymentState: baseState, recovery: baseState } +} + +export function createPaymentRecoverySnapshot( + state: Omit, + now = Date.now(), +): PaymentRecoverySnapshot { + return { + ...state, + createdAt: now, + } +} + +export function writePaymentRecoverySnapshot( + storage: StorageWriter, + snapshot: PaymentRecoverySnapshot, + key = PAYMENT_RECOVERY_STORAGE_KEY, +): void { + storage.setItem(key, JSON.stringify(snapshot)) +} + +export function clearPaymentRecoverySnapshot( + storage: Pick, + key = PAYMENT_RECOVERY_STORAGE_KEY, +): void { + storage.removeItem(key) +} + +export function readPaymentRecoverySnapshot( + raw: string | null | undefined, + options: { now?: number; resumeToken?: string } = {}, +): PaymentRecoverySnapshot | null { + if (!raw) return null + + try { + const parsed = JSON.parse(raw) as Partial + if ( + typeof parsed.orderId !== 'number' + || typeof parsed.amount !== 'number' + || typeof parsed.qrCode !== 'string' + || typeof parsed.expiresAt !== 'string' + || typeof parsed.paymentType !== 'string' + || typeof parsed.payUrl !== 'string' + || typeof parsed.clientSecret !== 'string' + || typeof parsed.payAmount !== 'number' + || typeof parsed.paymentMode !== 'string' + || typeof parsed.resumeToken !== 'string' + || typeof parsed.createdAt !== 'number' + ) { + return null + } + + const now = options.now ?? Date.now() + const expiresAt = Date.parse(parsed.expiresAt) + if (Number.isFinite(expiresAt) && expiresAt <= now) { + return null + } + if (options.resumeToken && parsed.resumeToken && parsed.resumeToken !== options.resumeToken) { + return null + } + + return { + orderId: parsed.orderId, + amount: parsed.amount, + qrCode: parsed.qrCode, + expiresAt: parsed.expiresAt, + paymentType: parsed.paymentType, + payUrl: parsed.payUrl, + clientSecret: parsed.clientSecret, + payAmount: parsed.payAmount, + orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance', + paymentMode: parsed.paymentMode, + resumeToken: parsed.resumeToken, + createdAt: parsed.createdAt, + } + } catch { + return null + } +} diff --git a/frontend/src/router/__tests__/wechat-route.spec.ts b/frontend/src/router/__tests__/wechat-route.spec.ts new file mode 100644 index 00000000..84b20452 --- /dev/null +++ b/frontend/src/router/__tests__/wechat-route.spec.ts @@ -0,0 +1,55 @@ +import { describe, expect, it, vi } from 'vitest' + +const authStore = vi.hoisted(() => ({ + checkAuth: vi.fn(), + isAuthenticated: false, + isAdmin: false, + isSimpleMode: false, +})) + +const appStore = vi.hoisted(() => ({ + siteName: 'Sub2API', + backendModeEnabled: false, + cachedPublicSettings: null as null | Record, +})) + +vi.mock('@/stores/auth', () => ({ + useAuthStore: () => authStore, +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => appStore, +})) + +vi.mock('@/stores/adminSettings', () => ({ + useAdminSettingsStore: () => ({ + customMenuItems: [], + }), +})) + +vi.mock('@/composables/useNavigationLoading', () => ({ + useNavigationLoadingState: () => ({ + startNavigation: vi.fn(), + endNavigation: vi.fn(), + isLoading: { value: false }, + }), +})) + +vi.mock('@/composables/useRoutePrefetch', () => ({ + useRoutePrefetch: () => ({ + triggerPrefetch: vi.fn(), + cancelPendingPrefetch: vi.fn(), + resetPrefetchState: vi.fn(), + }), +})) + +describe('router WeChat OAuth route', () => { + it('registers the WeChat callback route as a public route', async () => { + const { default: router } = await import('@/router') + const route = router.getRoutes().find((record) => record.name === 'WeChatOAuthCallback') + + expect(route?.path).toBe('/auth/wechat/callback') + expect(route?.meta.requiresAuth).toBe(false) + expect(route?.meta.title).toBe('WeChat OAuth Callback') + }) +}) diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index ad6e71c4..beaa1da2 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -83,6 +83,15 @@ const routes: RouteRecordRaw[] = [ title: 'LinuxDo OAuth Callback' } }, + { + path: '/auth/wechat/callback', + name: 'WeChatOAuthCallback', + component: () => import('@/views/auth/WechatCallbackView.vue'), + meta: { + requiresAuth: false, + title: 'WeChat OAuth Callback' + } + }, { path: '/auth/oidc/callback', name: 'OIDCOAuthCallback', diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 1995383d..1b1af87b 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -336,6 +336,7 @@ export const useAppStore = defineStore('app', () => { custom_menu_items: [], custom_endpoints: [], linuxdo_oauth_enabled: false, + wechat_oauth_enabled: false, oidc_oauth_enabled: false, oidc_oauth_provider_name: 'OIDC', backend_mode_enabled: false, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 89fd777f..529eff55 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -123,6 +123,7 @@ export interface PublicSettings { custom_menu_items: CustomMenuItem[] custom_endpoints: CustomEndpoint[] linuxdo_oauth_enabled: boolean + wechat_oauth_enabled: boolean oidc_oauth_enabled: boolean oidc_oauth_provider_name: string backend_mode_enabled: boolean diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 8bfa0f2b..0d23baa5 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1586,6 +1586,221 @@
+ +
+
+

+ {{ localText('认证来源默认值', 'Auth Source Defaults') }} +

+

+ {{ + localText( + '按注册来源配置新用户默认余额、并发、订阅与授权策略。', + 'Configure per-source default balance, concurrency, subscriptions, and grant rules.' + ) + }} +

+
+
+
+
+ +

+ {{ + localText( + '启用后,Linux DO、OIDC、微信注册缺少邮箱时必须先补充邮箱地址。', + 'When enabled, Linux DO, OIDC, and WeChat signups must provide an email before account creation.' + ) + }} +

+
+ +
+ +
+
+
+
{{ authSource.title }}
+

+ {{ authSource.description }} +

+
+ +
+
+ + +
+
+ + +
+
+ +
+
+
+ +

+ {{ + localText( + '来源首次注册成功后立即发放默认权益。', + 'Grant default entitlements immediately after signup.' + ) + }} +

+
+ +
+ +
+
+ +

+ {{ + localText( + '来源首次绑定到现有账号时发放默认权益。', + 'Grant default entitlements when the source is first bound to an existing user.' + ) + }} +

+
+ +
+
+ +
+
+
+ +

+ {{ + localText( + '仅对当前认证来源生效,未配置时不追加来源专属订阅。', + 'Applies only to this auth source. Leave empty to skip source-specific subscriptions.' + ) + }} +

+
+ +
+ +
+ {{ + localText( + '当前来源未配置专属默认订阅。', + 'No source-specific default subscriptions configured.' + ) + }} +
+ +
+
+
+ + +
+
+ + +
+
+ +
+
+
+
+
+
+
+
@@ -1643,19 +1858,38 @@

-
-
-
@@ -2450,6 +2684,59 @@

+
+
+
+
+ +

+ {{ + localText( + '控制前台结算页是否展示该方式,以及展示时使用的来源键。', + 'Controls whether checkout shows this method and which source key it exposes.' + ) + }} +

+
+ +
+ +
+ + +

+ {{ + localText( + '留空表示由后端使用默认来源;可填 easypay、alipay、wxpay 等来源标识。', + 'Leave blank to let the backend decide. Typical values are easypay, alipay, or wxpay.' + ) + }} +

+
+
+
@@ -2827,7 +3114,14 @@ import { ref, reactive, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { adminAPI } from '@/api' +import { + appendAuthSourceDefaultsToUpdateRequest, + buildAuthSourceDefaultsState, + normalizeDefaultSubscriptionSettings, +} from '@/api/admin/settings' import type { + AuthSourceDefaultsState, + AuthSourceType, SystemSettings, UpdateSettingsRequest, DefaultSubscriptionSetting, @@ -2864,6 +3158,10 @@ const { t, locale } = useI18n() const appStore = useAppStore() const adminSettingsStore = useAdminSettingsStore() +function localText(zh: string, en: string): string { + return locale.value.startsWith('zh') ? zh : en +} + type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'payment' | 'email' | 'backup' const activeTab = ref('general') const settingsTabs = [ @@ -2960,6 +3258,12 @@ type SettingsForm = SystemSettings & { turnstile_secret_key: string linuxdo_connect_client_secret: string oidc_connect_client_secret: string + force_email_on_third_party_signup: boolean + payment_visible_method_alipay_source: string + payment_visible_method_wxpay_source: string + payment_visible_method_alipay_enabled: boolean + payment_visible_method_wxpay_enabled: boolean + openai_advanced_scheduler_enabled: boolean } const form = reactive({ @@ -2974,6 +3278,7 @@ const form = reactive({ default_balance: 0, default_concurrency: 1, default_subscriptions: [], + force_email_on_third_party_signup: false, site_name: 'Sub2API', site_logo: '', site_subtitle: 'Subscription to API Conversion Platform', @@ -2983,7 +3288,7 @@ const form = reactive({ home_content: '', backend_mode_enabled: false, hide_ccs_import_button: false, - payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling', + payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling', payment_visible_method_alipay_source: '', payment_visible_method_wxpay_source: '', payment_visible_method_alipay_enabled: false, payment_visible_method_wxpay_enabled: false, table_default_page_size: tablePageSizeDefault, table_page_size_options: [10, 20, 50, 100], custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>, @@ -3051,6 +3356,7 @@ const form = reactive({ max_claude_code_version: '', // 分组隔离 allow_ungrouped_key_scheduling: false, + openai_advanced_scheduler_enabled: false, // Gateway forwarding behavior enable_fingerprint_unification: true, enable_metadata_passthrough: false, @@ -3063,6 +3369,74 @@ const form = reactive({ account_quota_notify_emails: [] as NotifyEmailEntry[] }) +const authSourceDefaults = reactive(buildAuthSourceDefaultsState({})) + +const authSourceDefaultsMeta = computed(() => [ + { + source: 'email' as AuthSourceType, + title: localText('邮箱注册', 'Email signup'), + description: localText('适用于邮箱密码注册的新用户默认配额。', 'Default quota grants for email-password signups.') + }, + { + source: 'linuxdo' as AuthSourceType, + title: localText('Linux DO 登录', 'Linux DO signup'), + description: localText('适用于 Linux DO 第三方注册的新用户默认配额。', 'Default quota grants for Linux DO signups.') + }, + { + source: 'oidc' as AuthSourceType, + title: localText('OIDC 登录', 'OIDC signup'), + description: localText('适用于 OIDC 第三方注册的新用户默认配额。', 'Default quota grants for OIDC signups.') + }, + { + source: 'wechat' as AuthSourceType, + title: localText('微信登录', 'WeChat signup'), + description: localText('适用于微信第三方注册的新用户默认配额。', 'Default quota grants for WeChat signups.') + }, +]) + +const paymentVisibleMethodCards = computed(() => [ + { + key: 'alipay' as const, + title: t('payment.methods.alipay'), + enabledField: 'payment_visible_method_alipay_enabled' as const, + sourceField: 'payment_visible_method_alipay_source' as const, + }, + { + key: 'wxpay' as const, + title: t('payment.methods.wxpay'), + enabledField: 'payment_visible_method_wxpay_enabled' as const, + sourceField: 'payment_visible_method_wxpay_source' as const, + }, +]) + +function getPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay'): boolean { + return method === 'alipay' + ? form.payment_visible_method_alipay_enabled + : form.payment_visible_method_wxpay_enabled +} + +function setPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay', enabled: boolean) { + if (method === 'alipay') { + form.payment_visible_method_alipay_enabled = enabled + return + } + form.payment_visible_method_wxpay_enabled = enabled +} + +function getPaymentVisibleMethodSource(method: 'alipay' | 'wxpay'): string { + return method === 'alipay' + ? form.payment_visible_method_alipay_source + : form.payment_visible_method_wxpay_source +} + +function setPaymentVisibleMethodSource(method: 'alipay' | 'wxpay', source: string) { + if (method === 'alipay') { + form.payment_visible_method_alipay_source = source + return + } + form.payment_visible_method_wxpay_source = source +} + // Proxies for web search emulation ProxySelector const webSearchProxies = ref([]) @@ -3428,15 +3802,9 @@ async function loadSettings() { (form as Record)[key] = value } } + Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(settings)) form.backend_mode_enabled = settings.backend_mode_enabled - form.default_subscriptions = Array.isArray(settings.default_subscriptions) - ? settings.default_subscriptions - .filter((item) => item.group_id > 0 && item.validity_days > 0) - .map((item) => ({ - group_id: item.group_id, - validity_days: item.validity_days - })) - : [] + form.default_subscriptions = normalizeDefaultSubscriptionSettings(settings.default_subscriptions) registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains( settings.registration_email_suffix_whitelist ) @@ -3471,10 +3839,18 @@ async function loadSubscriptionGroups() { } } +function findNextAvailableSubscriptionGroup( + existingGroupIDs: number[] +): AdminGroup | undefined { + const existing = new Set(existingGroupIDs) + return subscriptionGroups.value.find((group) => !existing.has(group.id)) +} + function addDefaultSubscription() { if (subscriptionGroups.value.length === 0) return - const existing = new Set(form.default_subscriptions.map((item) => item.group_id)) - const candidate = subscriptionGroups.value.find((group) => !existing.has(group.id)) + const candidate = findNextAvailableSubscriptionGroup( + form.default_subscriptions.map((item) => item.group_id) + ) if (!candidate) return form.default_subscriptions.push({ group_id: candidate.id, @@ -3486,6 +3862,36 @@ function removeDefaultSubscription(index: number) { form.default_subscriptions.splice(index, 1) } +function addAuthSourceDefaultSubscription(source: AuthSourceType) { + if (subscriptionGroups.value.length === 0) return + const candidate = findNextAvailableSubscriptionGroup( + authSourceDefaults[source].subscriptions.map((item) => item.group_id) + ) + if (!candidate) return + authSourceDefaults[source].subscriptions.push({ + group_id: candidate.id, + validity_days: 30 + }) +} + +function removeAuthSourceDefaultSubscription(source: AuthSourceType, index: number) { + authSourceDefaults[source].subscriptions.splice(index, 1) +} + +function findDuplicateDefaultSubscription( + subscriptions: DefaultSubscriptionSetting[] +): DefaultSubscriptionSetting | undefined { + const seenGroupIDs = new Set() + + return subscriptions.find((item) => { + if (seenGroupIDs.has(item.group_id)) { + return true + } + seenGroupIDs.add(item.group_id) + return false + }) +} + async function saveSettings() { saving.value = true try { @@ -3520,21 +3926,12 @@ async function saveSettings() { form.table_default_page_size = normalizedTableDefaultPageSize form.table_page_size_options = normalizedTablePageSizeOptions - const normalizedDefaultSubscriptions = form.default_subscriptions - .filter((item) => item.group_id > 0 && item.validity_days > 0) - .map((item: DefaultSubscriptionSetting) => ({ - group_id: item.group_id, - validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days))) - })) - - const seenGroupIDs = new Set() - const duplicateDefaultSubscription = normalizedDefaultSubscriptions.find((item) => { - if (seenGroupIDs.has(item.group_id)) { - return true - } - seenGroupIDs.add(item.group_id) - return false - }) + const normalizedDefaultSubscriptions = normalizeDefaultSubscriptionSettings( + form.default_subscriptions + ) + const duplicateDefaultSubscription = findDuplicateDefaultSubscription( + normalizedDefaultSubscriptions + ) if (duplicateDefaultSubscription) { appStore.showError( t('admin.settings.defaults.defaultSubscriptionsDuplicate', { @@ -3544,6 +3941,23 @@ async function saveSettings() { return } + for (const authSource of authSourceDefaultsMeta.value) { + authSourceDefaults[authSource.source].subscriptions = normalizeDefaultSubscriptionSettings( + authSourceDefaults[authSource.source].subscriptions + ) + const duplicate = findDuplicateDefaultSubscription( + authSourceDefaults[authSource.source].subscriptions + ) + if (duplicate) { + appStore.showError( + `${authSource.title}: ${t('admin.settings.defaults.defaultSubscriptionsDuplicate', { + groupId: duplicate.group_id + })}` + ) + return + } + } + // Validate URL fields — novalidate disables browser-native checks, so we validate here const isValidHttpUrl = (url: string): boolean => { if (!url) return true @@ -3571,6 +3985,7 @@ async function saveSettings() { default_balance: form.default_balance, default_concurrency: form.default_concurrency, default_subscriptions: normalizedDefaultSubscriptions, + force_email_on_third_party_signup: form.force_email_on_third_party_signup, site_name: form.site_name, site_logo: form.site_logo, site_subtitle: form.site_subtitle, @@ -3655,6 +4070,11 @@ async function saveSettings() { payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1, payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit, payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode, + payment_visible_method_alipay_source: form.payment_visible_method_alipay_source, + payment_visible_method_wxpay_source: form.payment_visible_method_wxpay_source, + payment_visible_method_alipay_enabled: form.payment_visible_method_alipay_enabled, + payment_visible_method_wxpay_enabled: form.payment_visible_method_wxpay_enabled, + openai_advanced_scheduler_enabled: form.openai_advanced_scheduler_enabled, // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, @@ -3663,12 +4083,15 @@ async function saveSettings() { account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''), } + appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults) + const updated = await adminAPI.settings.updateSettings(payload) for (const [key, value] of Object.entries(updated)) { if (value !== null && value !== undefined) { (form as Record)[key] = value } } + Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(updated)) registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains( updated.registration_email_suffix_whitelist ) diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index af48959b..0a125def 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -11,32 +11,94 @@
-
-

- {{ t('auth.linuxdo.invitationRequired') }} -

-
- -
- -

- {{ invitationError }} -

-
- +
+
+

+ Use LinuxDo profile details +

+

+ Choose whether to apply the nickname or avatar from LinuxDo to this account. +

+
+ + + + +
+
+ + + +
@@ -71,7 +133,12 @@ import { useI18n } from 'vue-i18n' import { AuthLayout } from '@/components/layout' import Icon from '@/components/icons/Icon.vue' import { useAuthStore, useAppStore } from '@/stores' -import { completeLinuxDoOAuthRegistration } from '@/api/auth' +import { + completeLinuxDoOAuthRegistration, + exchangePendingOAuthCompletion, + type OAuthAdoptionDecision, + type PendingOAuthExchangeResponse +} from '@/api/auth' const route = useRoute() const router = useRouter() @@ -85,11 +152,16 @@ const errorMessage = ref('') // Invitation code flow state const needsInvitation = ref(false) -const pendingOAuthToken = ref('') const invitationCode = ref('') const isSubmitting = ref(false) const invitationError = ref('') const redirectTo = ref('/dashboard') +const adoptionRequired = ref(false) +const suggestedDisplayName = ref('') +const suggestedAvatarUrl = ref('') +const adoptDisplayName = ref(true) +const adoptAvatar = ref(true) +const needsAdoptionConfirmation = ref(false) function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -106,6 +178,54 @@ function sanitizeRedirectPath(path: string | null | undefined): string { return path } +function currentAdoptionDecision(): OAuthAdoptionDecision { + return { + adoptDisplayName: adoptDisplayName.value, + adoptAvatar: adoptAvatar.value + } +} + +function applyAdoptionSuggestionState(completion: { + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +}) { + adoptionRequired.value = completion.adoption_required === true + suggestedDisplayName.value = completion.suggested_display_name || '' + suggestedAvatarUrl.value = completion.suggested_avatar_url || '' + + if (!suggestedDisplayName.value) { + adoptDisplayName.value = false + } + if (!suggestedAvatarUrl.value) { + adoptAvatar.value = false + } +} + +function hasSuggestedProfile(completion: { + suggested_display_name?: string + suggested_avatar_url?: string +}): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + +async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { + if (!completion.access_token) { + throw new Error(t('auth.linuxdo.callbackMissingToken')) + } + + if (completion.refresh_token) { + localStorage.setItem('refresh_token', completion.refresh_token) + } + if (completion.expires_in) { + localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) + } + + await authStore.setToken(completion.access_token) + appStore.showSuccess(t('auth.loginSuccess')) + await router.replace(redirect) +} + async function handleSubmitInvitation() { invitationError.value = '' if (!invitationCode.value.trim()) return @@ -113,8 +233,8 @@ async function handleSubmitInvitation() { isSubmitting.value = true try { const tokenData = await completeLinuxDoOAuthRegistration( - pendingOAuthToken.value, - invitationCode.value.trim() + invitationCode.value.trim(), + currentAdoptionDecision() ) if (tokenData.refresh_token) { localStorage.setItem('refresh_token', tokenData.refresh_token) @@ -134,63 +254,65 @@ async function handleSubmitInvitation() { } } +async function handleContinueLogin() { + isSubmitting.value = true + try { + const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) + await finalizeLogin(completion, redirectTo.value) + } catch (e: unknown) { + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') + appStore.showError(errorMessage.value) + needsAdoptionConfirmation.value = false + } finally { + isSubmitting.value = false + } +} + onMounted(async () => { const params = parseFragmentParams() - - const token = params.get('access_token') || '' - const refreshToken = params.get('refresh_token') || '' - const expiresInStr = params.get('expires_in') || '' - const redirect = sanitizeRedirectPath( - params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard' - ) const error = params.get('error') const errorDesc = params.get('error_description') || params.get('error_message') || '' if (error) { - if (error === 'invitation_required') { - pendingOAuthToken.value = params.get('pending_oauth_token') || '' - redirectTo.value = sanitizeRedirectPath(params.get('redirect')) - if (!pendingOAuthToken.value) { - errorMessage.value = t('auth.linuxdo.invalidPendingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - needsInvitation.value = true - isProcessing.value = false - return - } errorMessage.value = errorDesc || error appStore.showError(errorMessage.value) isProcessing.value = false return } - if (!token) { - errorMessage.value = t('auth.linuxdo.callbackMissingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - try { - // Store refresh token and expires_at (convert to timestamp) if provided - if (refreshToken) { - localStorage.setItem('refresh_token', refreshToken) - } - if (expiresInStr) { - const expiresIn = parseInt(expiresInStr, 10) - if (!isNaN(expiresIn)) { - localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000)) - } + const completion = await exchangePendingOAuthCompletion() + const redirect = sanitizeRedirectPath( + completion.redirect || (route.query.redirect as string | undefined) || '/dashboard' + ) + applyAdoptionSuggestionState(completion) + redirectTo.value = redirect + + if (completion.error === 'invitation_required') { + needsInvitation.value = true + isProcessing.value = false + return } - await authStore.setToken(token) - appStore.showSuccess(t('auth.loginSuccess')) - await router.replace(redirect) + if (adoptionRequired.value && hasSuggestedProfile(completion)) { + needsAdoptionConfirmation.value = true + isProcessing.value = false + return + } + + await finalizeLogin(completion, redirect) } catch (e: unknown) { - const err = e as { message?: string; response?: { data?: { detail?: string } } } - errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed') + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') appStore.showError(errorMessage.value) isProcessing.value = false } @@ -209,4 +331,3 @@ onMounted(async () => { transform: translateY(-8px); } - diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue index 70b64e3f..fa4ac34c 100644 --- a/frontend/src/views/auth/LoginView.vue +++ b/frontend/src/views/auth/LoginView.vue @@ -11,12 +11,17 @@

-
+
+ (false) const turnstileEnabled = ref(false) const turnstileSiteKey = ref('') const linuxdoOAuthEnabled = ref(false) +const wechatOAuthEnabled = ref(false) const backendModeEnabled = ref(false) const oidcOAuthEnabled = ref(false) const oidcOAuthProviderName = ref('OIDC') @@ -267,6 +274,7 @@ onMounted(async () => { turnstileEnabled.value = settings.turnstile_enabled turnstileSiteKey.value = settings.turnstile_site_key || '' linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled + wechatOAuthEnabled.value = settings.wechat_oauth_enabled backendModeEnabled.value = settings.backend_mode_enabled oidcOAuthEnabled.value = settings.oidc_oauth_enabled oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index a6cb6c12..55f8af6e 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -15,36 +15,99 @@
-
-

- {{ t('auth.oidc.invitationRequired', { providerName }) }} -

-
- -
- -

- {{ invitationError }} -

-
- +
+
+

+ Use {{ providerName }} profile details +

+

+ Choose whether to apply the nickname or avatar from {{ providerName }} to this + account. +

+
+ + + + +
+
+ + + +
@@ -81,7 +144,10 @@ import Icon from '@/components/icons/Icon.vue' import { useAuthStore, useAppStore } from '@/stores' import { completeOIDCOAuthRegistration, - getPublicSettings + exchangePendingOAuthCompletion, + getPublicSettings, + type OAuthAdoptionDecision, + type PendingOAuthExchangeResponse } from '@/api/auth' const route = useRoute() @@ -95,12 +161,17 @@ const isProcessing = ref(true) const errorMessage = ref('') const needsInvitation = ref(false) -const pendingOAuthToken = ref('') const invitationCode = ref('') const isSubmitting = ref(false) const invitationError = ref('') const redirectTo = ref('/dashboard') const providerName = ref('OIDC') +const adoptionRequired = ref(false) +const suggestedDisplayName = ref('') +const suggestedAvatarUrl = ref('') +const adoptDisplayName = ref(true) +const adoptAvatar = ref(true) +const needsAdoptionConfirmation = ref(false) function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -129,6 +200,54 @@ async function loadProviderName() { } } +function currentAdoptionDecision(): OAuthAdoptionDecision { + return { + adoptDisplayName: adoptDisplayName.value, + adoptAvatar: adoptAvatar.value + } +} + +function applyAdoptionSuggestionState(completion: { + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +}) { + adoptionRequired.value = completion.adoption_required === true + suggestedDisplayName.value = completion.suggested_display_name || '' + suggestedAvatarUrl.value = completion.suggested_avatar_url || '' + + if (!suggestedDisplayName.value) { + adoptDisplayName.value = false + } + if (!suggestedAvatarUrl.value) { + adoptAvatar.value = false + } +} + +function hasSuggestedProfile(completion: { + suggested_display_name?: string + suggested_avatar_url?: string +}): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + +async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { + if (!completion.access_token) { + throw new Error(t('auth.oidc.callbackMissingToken')) + } + + if (completion.refresh_token) { + localStorage.setItem('refresh_token', completion.refresh_token) + } + if (completion.expires_in) { + localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) + } + + await authStore.setToken(completion.access_token) + appStore.showSuccess(t('auth.loginSuccess')) + await router.replace(redirect) +} + async function handleSubmitInvitation() { invitationError.value = '' if (!invitationCode.value.trim()) return @@ -136,8 +255,8 @@ async function handleSubmitInvitation() { isSubmitting.value = true try { const tokenData = await completeOIDCOAuthRegistration( - pendingOAuthToken.value, - invitationCode.value.trim() + invitationCode.value.trim(), + currentAdoptionDecision() ) if (tokenData.refresh_token) { localStorage.setItem('refresh_token', tokenData.refresh_token) @@ -157,63 +276,67 @@ async function handleSubmitInvitation() { } } +async function handleContinueLogin() { + isSubmitting.value = true + try { + const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) + await finalizeLogin(completion, redirectTo.value) + } catch (e: unknown) { + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') + appStore.showError(errorMessage.value) + needsAdoptionConfirmation.value = false + } finally { + isSubmitting.value = false + } +} + onMounted(async () => { void loadProviderName() const params = parseFragmentParams() - const token = params.get('access_token') || '' - const refreshToken = params.get('refresh_token') || '' - const expiresInStr = params.get('expires_in') || '' - const redirect = sanitizeRedirectPath( - params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard' - ) const error = params.get('error') const errorDesc = params.get('error_description') || params.get('error_message') || '' if (error) { - if (error === 'invitation_required') { - pendingOAuthToken.value = params.get('pending_oauth_token') || '' - redirectTo.value = sanitizeRedirectPath(params.get('redirect')) - if (!pendingOAuthToken.value) { - errorMessage.value = t('auth.oidc.invalidPendingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - needsInvitation.value = true - isProcessing.value = false - return - } errorMessage.value = errorDesc || error appStore.showError(errorMessage.value) isProcessing.value = false return } - if (!token) { - errorMessage.value = t('auth.oidc.callbackMissingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - try { - if (refreshToken) { - localStorage.setItem('refresh_token', refreshToken) - } - if (expiresInStr) { - const expiresIn = parseInt(expiresInStr, 10) - if (!isNaN(expiresIn)) { - localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000)) - } + const completion = await exchangePendingOAuthCompletion() + const redirect = sanitizeRedirectPath( + completion.redirect || (route.query.redirect as string | undefined) || '/dashboard' + ) + applyAdoptionSuggestionState(completion) + redirectTo.value = redirect + + if (completion.error === 'invitation_required') { + needsInvitation.value = true + isProcessing.value = false + return } - await authStore.setToken(token) - appStore.showSuccess(t('auth.loginSuccess')) - await router.replace(redirect) + if (adoptionRequired.value && hasSuggestedProfile(completion)) { + needsAdoptionConfirmation.value = true + isProcessing.value = false + return + } + + await finalizeLogin(completion, redirect) } catch (e: unknown) { - const err = e as { message?: string; response?: { data?: { detail?: string } } } - errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed') + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') appStore.showError(errorMessage.value) isProcessing.value = false } diff --git a/frontend/src/views/auth/RegisterView.vue b/frontend/src/views/auth/RegisterView.vue index bc8b8dce..378f9d8a 100644 --- a/frontend/src/views/auth/RegisterView.vue +++ b/frontend/src/views/auth/RegisterView.vue @@ -11,12 +11,17 @@

-
+
+ (false) const turnstileSiteKey = ref('') const siteName = ref('Sub2API') const linuxdoOAuthEnabled = ref(false) +const wechatOAuthEnabled = ref(false) const oidcOAuthEnabled = ref(false) const oidcOAuthProviderName = ref('OIDC') const registrationEmailSuffixWhitelist = ref([]) @@ -397,6 +404,7 @@ onMounted(async () => { turnstileSiteKey.value = settings.turnstile_site_key || '' siteName.value = settings.site_name || 'Sub2API' linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled + wechatOAuthEnabled.value = settings.wechat_oauth_enabled oidcOAuthEnabled.value = settings.oidc_oauth_enabled oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' registrationEmailSuffixWhitelist.value = normalizeRegistrationEmailSuffixWhitelist( diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue new file mode 100644 index 00000000..407b395b --- /dev/null +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -0,0 +1,361 @@ + + + + + diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts new file mode 100644 index 00000000..60a40474 --- /dev/null +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -0,0 +1,180 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import LinuxDoCallbackView from '../LinuxDoCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeLinuxDoOAuthRegistration = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/auth', () => ({ + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args) +})) + +describe('LinuxDoCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeLinuxDoOAuthRegistration.mockReset() + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + completeLinuxDoOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + + await checkboxes[0].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeLinuxDoOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + }) +}) diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts new file mode 100644 index 00000000..299c0746 --- /dev/null +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -0,0 +1,191 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import OidcCallbackView from '../OidcCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeOIDCOAuthRegistration = vi.fn() +const getPublicSettings = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (!params?.providerName) { + return key + } + return `${key}:${params.providerName}` + } + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/auth', () => ({ + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args) +})) + +describe('OidcCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeOIDCOAuthRegistration.mockReset() + getPublicSettings.mockReset() + getPublicSettings.mockResolvedValue({ + oidc_oauth_provider_name: 'ExampleID' + }) + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('OIDC Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[0].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: false, + adoptAvatar: true + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + completeOIDCOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('OIDC Nick') + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + + await checkboxes[1].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeOIDCOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + }) +}) diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts new file mode 100644 index 00000000..a9e2ada2 --- /dev/null +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -0,0 +1,241 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import WechatCallbackView from '@/views/auth/WechatCallbackView.vue' + +const { + postMock, + replaceMock, + setTokenMock, + showSuccessMock, + showErrorMock, + routeState, +} = vi.hoisted(() => ({ + postMock: vi.fn(), + replaceMock: vi.fn(), + setTokenMock: vi.fn(), + showSuccessMock: vi.fn(), + showErrorMock: vi.fn(), + routeState: { + query: {} as Record, + }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, + useRouter: () => ({ + replace: replaceMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + createI18n: () => ({ + global: { + t: (key: string) => key, + }, + }), + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oidc.callbackTitle') { + return `Signing you in with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.callbackProcessing') { + return `Completing login with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.invitationRequired') { + return `${params?.providerName ?? ''} invitation required`.trim() + } + if (key === 'auth.oidc.completeRegistration') { + return 'Complete registration' + } + if (key === 'auth.oidc.completing') { + return 'Completing' + } + if (key === 'auth.oidc.backToLogin') { + return 'Back to login' + } + if (key === 'auth.invitationCodePlaceholder') { + return 'Invitation code' + } + if (key === 'auth.loginSuccess') { + return 'Login success' + } + if (key === 'auth.loginFailed') { + return 'Login failed' + } + if (key === 'auth.oidc.callbackHint') { + return 'Callback hint' + } + if (key === 'auth.oidc.callbackMissingToken') { + return 'Missing login token' + } + if (key === 'auth.oidc.completeRegistrationFailed') { + return 'Complete registration failed' + } + return key + }, + }), +})) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken: setTokenMock, + }), + useAppStore: () => ({ + showSuccess: showSuccessMock, + showError: showErrorMock, + }), +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: postMock, + }, +})) + +describe('WechatCallbackView', () => { + beforeEach(() => { + postMock.mockReset() + replaceMock.mockReset() + setTokenMock.mockReset() + showSuccessMock.mockReset() + showErrorMock.mockReset() + routeState.query = {} + localStorage.clear() + }) + + it('does not send adoption decisions during the initial exchange', async () => { + postMock.mockResolvedValueOnce({ + data: { + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true, + }, + }) + setTokenMock.mockResolvedValue({}) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {}) + expect(postMock).toHaveBeenCalledTimes(1) + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + postMock + .mockResolvedValueOnce({ + data: { + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }, + }) + .mockResolvedValueOnce({ + data: { + access_token: 'wechat-access-token', + refresh_token: 'wechat-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + redirect: '/dashboard', + }, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + expect(setTokenMock).not.toHaveBeenCalled() + expect(replaceMock).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {}) + expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', { + adopt_display_name: true, + adopt_avatar: false, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token') + expect(replaceMock).toHaveBeenCalledWith('/dashboard') + expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + postMock + .mockResolvedValueOnce({ + data: { + error: 'invitation_required', + redirect: '/subscriptions', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }, + }) + .mockResolvedValueOnce({ + data: { + access_token: 'wechat-invite-token', + refresh_token: 'wechat-invite-refresh', + expires_in: 600, + token_type: 'Bearer', + }, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('input[type="text"]').setValue(' INVITE-CODE ') + await wrapper.get('button').trigger('click') + await flushPromises() + + expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', { + invitation_code: 'INVITE-CODE', + adopt_display_name: false, + adopt_avatar: true, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token') + expect(replaceMock).toHaveBeenCalledWith('/subscriptions') + }) +}) diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index e91df5da..838f3000 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -23,20 +23,7 @@ :order-type="paymentState.orderType" @done="onPaymentDone" @success="onPaymentSuccess" - /> - - @@ -265,7 +252,7 @@ diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue index b6f6022d..e82ae229 100644 --- a/frontend/src/components/user/profile/ProfileInfoCard.vue +++ b/frontend/src/components/user/profile/ProfileInfoCard.vue @@ -4,11 +4,16 @@ class="border-b border-gray-100 bg-gradient-to-r from-primary-500/10 to-primary-600/5 px-6 py-5 dark:border-dark-700 dark:from-primary-500/20 dark:to-primary-600/10" >
-
- {{ user?.email?.charAt(0).toUpperCase() || 'U' }} + + {{ avatarInitial }}

@@ -41,18 +46,163 @@ {{ user.username }}

+ +
+
+ + {{ hint.text }} +
+
+ +
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts new file mode 100644 index 00000000..1c9531e3 --- /dev/null +++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts @@ -0,0 +1,120 @@ +import { mount } from '@vue/test-utils' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue' +import type { User } from '@/types' + +const routeState = vi.hoisted(() => ({ + fullPath: '/profile', +})) + +const locationState = vi.hoisted(() => ({ + current: { href: 'http://localhost/profile' } as { href: string }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'profile.authBindings.title') return 'Connected sign-in methods' + if (key === 'profile.authBindings.description') return 'Manage bound providers' + if (key === 'profile.authBindings.status.bound') return 'Bound' + if (key === 'profile.authBindings.status.notBound') return 'Not bound' + if (key === 'profile.authBindings.providers.email') return 'Email' + if (key === 'profile.authBindings.providers.linuxdo') return 'LinuxDo' + if (key === 'profile.authBindings.providers.wechat') return 'WeChat' + if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC' + if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim() + return key + }, + }), + } +}) + +function createUser(overrides: Partial = {}): User { + return { + id: 7, + username: 'alice', + email: 'alice@example.com', + role: 'user', + balance: 10, + concurrency: 2, + status: 'active', + allowed_groups: null, + balance_notify_enabled: true, + balance_notify_threshold: null, + balance_notify_extra_emails: [], + created_at: '2026-04-20T00:00:00Z', + updated_at: '2026-04-20T00:00:00Z', + ...overrides, + } +} + +describe('ProfileIdentityBindingsSection', () => { + beforeEach(() => { + routeState.fullPath = '/profile' + locationState.current = { href: 'http://localhost/profile' } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0', + }) + }) + + afterEach(() => { + vi.unstubAllGlobals() + }) + + it('renders provider binding states and provider-specific bind actions', () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + props: { + user: createUser({ + auth_bindings: { + email: { bound: true }, + linuxdo: { bound: true }, + oidc: { bound: false }, + wechat: false, + }, + }), + linuxdoEnabled: true, + oidcEnabled: true, + oidcProviderName: 'ExampleID', + wechatEnabled: true, + }, + }) + + expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound') + expect(wrapper.get('[data-testid="profile-binding-linuxdo-status"]').text()).toBe('Bound') + expect(wrapper.get('[data-testid="profile-binding-oidc-status"]').text()).toBe('Not bound') + expect(wrapper.get('[data-testid="profile-binding-oidc-action"]').text()).toBe( + 'Bind ExampleID' + ) + expect(wrapper.get('[data-testid="profile-binding-wechat-action"]').text()).toBe('Bind WeChat') + }) + + it('starts the WeChat bind flow for the current profile page', async () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + props: { + user: createUser(), + linuxdoEnabled: false, + oidcEnabled: false, + wechatEnabled: true, + }, + }) + + await wrapper.get('[data-testid="profile-binding-wechat-action"]').trigger('click') + + expect(locationState.current.href).toContain('/api/v1/auth/oauth/wechat/start?') + expect(locationState.current.href).toContain('mode=open') + expect(locationState.current.href).toContain('intent=bind_current_user') + expect(locationState.current.href).toContain('redirect=%2Fprofile') + }) +}) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 684c196f..7d058a74 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -940,6 +940,26 @@ export default { maxEmailsReached: 'Maximum number of notification emails reached', unverified: 'Unverified', verified: 'Verified', + }, + authBindings: { + title: 'Connected Sign-In Methods', + description: 'View current bindings and connect another provider to this account.', + bindAction: 'Bind {providerName}', + bindSuccess: 'Account linked successfully', + status: { + bound: 'Bound', + notBound: 'Not bound', + }, + providers: { + email: 'Email', + linuxdo: 'LinuxDo', + oidc: '{providerName}', + wechat: 'WeChat', + }, + source: { + avatar: 'Avatar is currently synced from {providerName}', + username: 'Nickname is currently synced from {providerName}', + }, } }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 2a4c69a5..6dd74334 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -944,6 +944,26 @@ export default { maxEmailsReached: '已达到通知邮箱数量上限', unverified: '未验证', verified: '已验证', + }, + authBindings: { + title: '登录方式绑定', + description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。', + bindAction: '绑定 {providerName}', + bindSuccess: '账号绑定成功', + status: { + bound: '已绑定', + notBound: '未绑定', + }, + providers: { + email: '邮箱', + linuxdo: 'LinuxDo', + oidc: '{providerName}', + wechat: '微信', + }, + source: { + avatar: '头像当前来自 {providerName}', + username: '昵称当前来自 {providerName}', + }, } }, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 9c9722a9..a19d6c26 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -34,10 +34,47 @@ export interface NotifyEmailEntry { // ==================== User & Auth Types ==================== +export type UserAuthProvider = 'email' | 'linuxdo' | 'oidc' | 'wechat' + +export interface UserAuthBindingStatus { + bound?: boolean + provider?: UserAuthProvider | string + provider_key?: string | null + provider_subject?: string | null + issuer?: string | null + label?: string | null + provider_label?: string | null + metadata?: Record +} + +export interface UserProfileSourceContext { + provider?: UserAuthProvider | string + source?: string | null + label?: string | null + provider_label?: string | null +} + export interface User { id: number username: string email: string + avatar_url?: string | null + avatar_source?: string | UserProfileSourceContext | null + username_source?: string | UserProfileSourceContext | null + display_name_source?: string | UserProfileSourceContext | null + nickname_source?: string | UserProfileSourceContext | null + profile_sources?: { + avatar?: string | UserProfileSourceContext | null + username?: string | UserProfileSourceContext | null + display_name?: string | UserProfileSourceContext | null + nickname?: string | UserProfileSourceContext | null + } + auth_bindings?: Partial> + identity_bindings?: Partial> + email_bound?: boolean + linuxdo_bound?: boolean + oidc_bound?: boolean + wechat_bound?: boolean role: 'admin' | 'user' // User role for authorization balance: number // User balance for API usage concurrency: number // Allowed concurrent requests diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index 0a125def..6dc8f242 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -136,6 +136,9 @@ import { useAuthStore, useAppStore } from '@/stores' import { completeLinuxDoOAuthRegistration, exchangePendingOAuthCompletion, + getOAuthCompletionKind, + isOAuthLoginCompletion, + persistOAuthTokenContext, type OAuthAdoptionDecision, type PendingOAuthExchangeResponse } from '@/api/auth' @@ -162,6 +165,7 @@ const suggestedAvatarUrl = ref('') const adoptDisplayName = ref(true) const adoptAvatar = ref(true) const needsAdoptionConfirmation = ref(false) +const bindSuccessMessage = t('profile.authBindings.bindSuccess') function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -209,18 +213,19 @@ function hasSuggestedProfile(completion: { return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) } -async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { - if (!completion.access_token) { +async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { + if (getOAuthCompletionKind(completion) === 'bind') { + const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') + appStore.showSuccess(bindSuccessMessage) + await router.replace(bindRedirect) + return + } + + if (!isOAuthLoginCompletion(completion)) { throw new Error(t('auth.linuxdo.callbackMissingToken')) } - if (completion.refresh_token) { - localStorage.setItem('refresh_token', completion.refresh_token) - } - if (completion.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) - } - + persistOAuthTokenContext(completion) await authStore.setToken(completion.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) @@ -236,12 +241,7 @@ async function handleSubmitInvitation() { invitationCode.value.trim(), currentAdoptionDecision() ) - if (tokenData.refresh_token) { - localStorage.setItem('refresh_token', tokenData.refresh_token) - } - if (tokenData.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000)) - } + persistOAuthTokenContext(tokenData) await authStore.setToken(tokenData.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirectTo.value) @@ -258,7 +258,7 @@ async function handleContinueLogin() { isSubmitting.value = true try { const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) - await finalizeLogin(completion, redirectTo.value) + await finalizeCompletion(completion, redirectTo.value) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = @@ -305,7 +305,7 @@ onMounted(async () => { return } - await finalizeLogin(completion, redirect) + await finalizeCompletion(completion, redirect) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index 55f8af6e..83304226 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -145,7 +145,10 @@ import { useAuthStore, useAppStore } from '@/stores' import { completeOIDCOAuthRegistration, exchangePendingOAuthCompletion, + getOAuthCompletionKind, getPublicSettings, + isOAuthLoginCompletion, + persistOAuthTokenContext, type OAuthAdoptionDecision, type PendingOAuthExchangeResponse } from '@/api/auth' @@ -172,6 +175,7 @@ const suggestedAvatarUrl = ref('') const adoptDisplayName = ref(true) const adoptAvatar = ref(true) const needsAdoptionConfirmation = ref(false) +const bindSuccessMessage = t('profile.authBindings.bindSuccess') function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -231,18 +235,19 @@ function hasSuggestedProfile(completion: { return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) } -async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { - if (!completion.access_token) { +async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { + if (getOAuthCompletionKind(completion) === 'bind') { + const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') + appStore.showSuccess(bindSuccessMessage) + await router.replace(bindRedirect) + return + } + + if (!isOAuthLoginCompletion(completion)) { throw new Error(t('auth.oidc.callbackMissingToken')) } - if (completion.refresh_token) { - localStorage.setItem('refresh_token', completion.refresh_token) - } - if (completion.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) - } - + persistOAuthTokenContext(completion) await authStore.setToken(completion.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) @@ -258,12 +263,7 @@ async function handleSubmitInvitation() { invitationCode.value.trim(), currentAdoptionDecision() ) - if (tokenData.refresh_token) { - localStorage.setItem('refresh_token', tokenData.refresh_token) - } - if (tokenData.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000)) - } + persistOAuthTokenContext(tokenData) await authStore.setToken(tokenData.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirectTo.value) @@ -280,7 +280,7 @@ async function handleContinueLogin() { isSubmitting.value = true try { const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) - await finalizeLogin(completion, redirectTo.value) + await finalizeCompletion(completion, redirectTo.value) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = @@ -329,7 +329,7 @@ onMounted(async () => { return } - await finalizeLogin(completion, redirect) + await finalizeCompletion(completion, redirect) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue index 407b395b..ac0ede4c 100644 --- a/frontend/src/views/auth/WechatCallbackView.vue +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -140,27 +140,16 @@ import { useRoute, useRouter } from 'vue-router' import { useI18n } from 'vue-i18n' import { AuthLayout } from '@/components/layout' import Icon from '@/components/icons/Icon.vue' -import { apiClient } from '@/api/client' import { useAuthStore, useAppStore } from '@/stores' - -interface OAuthTokenResponse { - access_token: string - refresh_token: string - expires_in: number - token_type: string -} - -interface PendingOAuthExchangeResponse { - access_token?: string - refresh_token?: string - expires_in?: number - token_type?: string - redirect?: string - error?: string - adoption_required?: boolean - suggested_display_name?: string - suggested_avatar_url?: string -} +import { + completeWeChatOAuthRegistration, + exchangePendingOAuthCompletion, + getOAuthCompletionKind, + isOAuthLoginCompletion, + persistOAuthTokenContext, + type OAuthAdoptionDecision, + type PendingOAuthExchangeResponse +} from '@/api/auth' const route = useRoute() const router = useRouter() @@ -182,6 +171,7 @@ const suggestedAvatarUrl = ref('') const adoptDisplayName = ref(true) const adoptAvatar = ref(true) const needsAdoptionConfirmation = ref(false) +const bindSuccessMessage = t('profile.authBindings.bindSuccess') const providerName = 'WeChat' @@ -200,10 +190,10 @@ function sanitizeRedirectPath(path: string | null | undefined): string { return path } -function currentAdoptionDecision(): Record { +function currentAdoptionDecision(): OAuthAdoptionDecision { return { - adopt_display_name: adoptDisplayName.value, - adopt_avatar: adoptAvatar.value, + adoptDisplayName: adoptDisplayName.value, + adoptAvatar: adoptAvatar.value } } @@ -224,49 +214,35 @@ function hasSuggestedProfile(completion: PendingOAuthExchangeResponse): boolean return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) } -async function exchangePendingOAuthCompletion(): Promise { - const { data } = await apiClient.post('/auth/oauth/pending/exchange', {}) - return data -} +async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { + if (getOAuthCompletionKind(completion) === 'bind') { + const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') + appStore.showSuccess(bindSuccessMessage) + await router.replace(bindRedirect) + return + } -async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { - if (!completion.access_token) { + if (!isOAuthLoginCompletion(completion)) { throw new Error(t('auth.oidc.callbackMissingToken')) } - if (completion.refresh_token) { - localStorage.setItem('refresh_token', completion.refresh_token) - } - if (completion.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) - } - + persistOAuthTokenContext(completion) await authStore.setToken(completion.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) } -async function completeWeChatOAuthRegistration(invitation: string): Promise { - const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', { - invitation_code: invitation, - ...currentAdoptionDecision(), - }) - return data -} - async function handleSubmitInvitation() { invitationError.value = '' if (!invitationCode.value.trim()) return isSubmitting.value = true try { - const tokenData = await completeWeChatOAuthRegistration(invitationCode.value.trim()) - if (tokenData.refresh_token) { - localStorage.setItem('refresh_token', tokenData.refresh_token) - } - if (tokenData.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000)) - } + const tokenData = await completeWeChatOAuthRegistration( + invitationCode.value.trim(), + currentAdoptionDecision() + ) + persistOAuthTokenContext(tokenData) await authStore.setToken(tokenData.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirectTo.value) @@ -282,11 +258,8 @@ async function handleSubmitInvitation() { async function handleContinueLogin() { isSubmitting.value = true try { - const { data } = await apiClient.post( - '/auth/oauth/pending/exchange', - currentAdoptionDecision() - ) - await finalizeLogin(data, redirectTo.value) + const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) + await finalizeCompletion(completion, redirectTo.value) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = @@ -333,7 +306,7 @@ onMounted(async () => { return } - await finalizeLogin(completion, redirect) + await finalizeCompletion(completion, redirect) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts index 60a40474..7ffdcd19 100644 --- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -39,10 +39,14 @@ vi.mock('@/stores', () => ({ }) })) -vi.mock('@/api/auth', () => ({ - exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), - completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args) -})) +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args) + } +}) describe('LinuxDoCallbackView', () => { beforeEach(() => { @@ -132,6 +136,64 @@ describe('LinuxDoCallbackView', () => { expect(replace).toHaveBeenCalledWith('/dashboard') }) + it('treats a completion without token as bind success and returns to profile', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile/security' + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile/security') + }) + it('renders adoption choices for invitation flow and submits the selected values', async () => { exchangePendingOAuthCompletion.mockResolvedValue({ error: 'invitation_required', diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts index 299c0746..f8de79f2 100644 --- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -45,11 +45,15 @@ vi.mock('@/stores', () => ({ }) })) -vi.mock('@/api/auth', () => ({ - exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), - completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), - getPublicSettings: (...args: any[]) => getPublicSettings(...args) -})) +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args) + } +}) describe('OidcCallbackView', () => { beforeEach(() => { @@ -143,6 +147,43 @@ describe('OidcCallbackView', () => { expect(replace).toHaveBeenCalledWith('/dashboard') }) + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile' + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + it('renders adoption choices for invitation flow and submits the selected values', async () => { exchangePendingOAuthCompletion.mockResolvedValue({ error: 'invitation_required', diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts index a9e2ada2..896bf15d 100644 --- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -3,14 +3,16 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import WechatCallbackView from '@/views/auth/WechatCallbackView.vue' const { - postMock, + exchangePendingOAuthCompletionMock, + completeWeChatOAuthRegistrationMock, replaceMock, setTokenMock, showSuccessMock, showErrorMock, routeState, } = vi.hoisted(() => ({ - postMock: vi.fn(), + exchangePendingOAuthCompletionMock: vi.fn(), + completeWeChatOAuthRegistrationMock: vi.fn(), replaceMock: vi.fn(), setTokenMock: vi.fn(), showSuccessMock: vi.fn(), @@ -86,15 +88,19 @@ vi.mock('@/stores', () => ({ }), })) -vi.mock('@/api/client', () => ({ - apiClient: { - post: postMock, - }, -})) +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletionMock(...args), + completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args), + } +}) describe('WechatCallbackView', () => { beforeEach(() => { - postMock.mockReset() + exchangePendingOAuthCompletionMock.mockReset() + completeWeChatOAuthRegistrationMock.mockReset() replaceMock.mockReset() setTokenMock.mockReset() showSuccessMock.mockReset() @@ -104,14 +110,12 @@ describe('WechatCallbackView', () => { }) it('does not send adoption decisions during the initial exchange', async () => { - postMock.mockResolvedValueOnce({ - data: { - access_token: 'access-token', - refresh_token: 'refresh-token', - expires_in: 3600, - redirect: '/dashboard', - adoption_required: true, - }, + exchangePendingOAuthCompletionMock.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true, }) setTokenMock.mockResolvedValue({}) @@ -128,28 +132,24 @@ describe('WechatCallbackView', () => { await flushPromises() - expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {}) - expect(postMock).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledWith() + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledTimes(1) }) it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { - postMock + exchangePendingOAuthCompletionMock .mockResolvedValueOnce({ - data: { - redirect: '/dashboard', - adoption_required: true, - suggested_display_name: 'WeChat Nick', - suggested_avatar_url: 'https://cdn.example/wechat.png', - }, + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', }) .mockResolvedValueOnce({ - data: { - access_token: 'wechat-access-token', - refresh_token: 'wechat-refresh-token', - expires_in: 3600, - token_type: 'Bearer', - redirect: '/dashboard', - }, + access_token: 'wechat-access-token', + refresh_token: 'wechat-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + redirect: '/dashboard', }) setTokenMock.mockResolvedValue({}) @@ -179,35 +179,67 @@ describe('WechatCallbackView', () => { await buttons[0].trigger('click') await flushPromises() - expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {}) - expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', { - adopt_display_name: true, - adopt_avatar: false, + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false, }) expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token') expect(replaceMock).toHaveBeenCalledWith('/dashboard') expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token') }) + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletionMock + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + .mockResolvedValueOnce({ + redirect: '/profile/connections', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true, + }) + expect(setTokenMock).not.toHaveBeenCalled() + expect(showSuccessMock).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replaceMock).toHaveBeenCalledWith('/profile/connections') + }) + it('renders adoption choices for invitation flow and submits the selected values', async () => { - postMock - .mockResolvedValueOnce({ - data: { - error: 'invitation_required', - redirect: '/subscriptions', - adoption_required: true, - suggested_display_name: 'WeChat Nick', - suggested_avatar_url: 'https://cdn.example/wechat.png', - }, - }) - .mockResolvedValueOnce({ - data: { - access_token: 'wechat-invite-token', - refresh_token: 'wechat-invite-refresh', - expires_in: 600, - token_type: 'Bearer', - }, - }) + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/subscriptions', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + completeWeChatOAuthRegistrationMock.mockResolvedValue({ + access_token: 'wechat-invite-token', + refresh_token: 'wechat-invite-refresh', + expires_in: 600, + token_type: 'Bearer', + }) const wrapper = mount(WechatCallbackView, { global: { @@ -230,10 +262,9 @@ describe('WechatCallbackView', () => { await wrapper.get('button').trigger('click') await flushPromises() - expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', { - invitation_code: 'INVITE-CODE', - adopt_display_name: false, - adopt_avatar: true, + expect(completeWeChatOAuthRegistrationMock).toHaveBeenCalledWith('INVITE-CODE', { + adoptDisplayName: false, + adoptAvatar: true, }) expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token') expect(replaceMock).toHaveBeenCalledWith('/subscriptions') diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue index e7418ebb..f7418be9 100644 --- a/frontend/src/views/user/ProfileView.vue +++ b/frontend/src/views/user/ProfileView.vue @@ -2,18 +2,53 @@
- - - + + +
- -
+ + + +
-
-

{{ t('common.contactSupport') }}

{{ contactInfo }}

+
+ +
+
+

+ {{ t('common.contactSupport') }} +

+

{{ contactInfo }}

+
+ + +
@@ -29,26 +65,78 @@ \ No newline at end of file +onMounted(async () => { + const profileRefresh = authStore.refreshUser().catch((error) => { + console.error('Failed to refresh profile:', error) + }) + + const settingsLoad = authAPI.getPublicSettings() + .then((settings) => { + contactInfo.value = settings.contact_info || '' + balanceLowNotifyEnabled.value = settings.balance_low_notify_enabled ?? false + systemDefaultThreshold.value = settings.balance_low_notify_threshold ?? 0 + linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled ?? false + wechatOAuthEnabled.value = settings.wechat_oauth_enabled ?? false + oidcOAuthEnabled.value = settings.oidc_oauth_enabled ?? false + oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' + }) + .catch((error) => { + console.error('Failed to load settings:', error) + }) + + await Promise.all([profileRefresh, settingsLoad]) +}) + +const formatCurrency = (value: number) => `$${value.toFixed(2)}` + From 4e0e69154649def4f4149054e09ff03fc8a3e50a Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 18:39:53 +0800 Subject: [PATCH 022/326] feat: apply auth source signup defaults --- .../service/admin_service_apikey_test.go | 3 + .../service/admin_service_delete_test.go | 51 ++++-- backend/internal/service/auth_service.go | 117 +++++++++---- .../service/auth_service_register_test.go | 159 +++++++++++++++++- 4 files changed, 283 insertions(+), 47 deletions(-) diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index b802a9c2..487fb5f1 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -79,6 +79,9 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { panic("unexpected") } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 323286b0..ac1d8ee7 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -13,15 +13,18 @@ import ( ) type userRepoStub struct { - user *User - getErr error - createErr error - deleteErr error - exists bool - existsErr error - nextID int64 - created []*User - deletedIDs []int64 + user *User + getErr error + createErr error + deleteErr error + exists bool + existsErr error + nextID int64 + created []*User + updated []*User + deletedIDs []int64 + usersByEmail map[string]*User + getByEmailErr error } func (s *userRepoStub) Create(ctx context.Context, user *User) error { @@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error { user.ID = s.nextID } s.created = append(s.created, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user return nil } @@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) { } func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) { - panic("unexpected GetByEmail call") + if s.getByEmailErr != nil { + return nil, s.getByEmailErr + } + if s.usersByEmail != nil { + if user, ok := s.usersByEmail[email]; ok { + return user, nil + } + } + if s.user != nil && s.user.Email == email { + return s.user, nil + } + return nil, ErrUserNotFound } func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { @@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { } func (s *userRepoStub) Update(ctx context.Context, user *User) error { - panic("unexpected Update call") + s.updated = append(s.updated, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user + return nil } func (s *userRepoStub) Delete(ctx context.Context, id int64) error { @@ -113,6 +138,10 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64 panic("unexpected AddGroupToAllowedGroups call") } +func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected ListUserAuthIdentities call") +} + func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { panic("unexpected UpdateTotpSecret call") } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 962009ce..40753139 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -78,6 +78,12 @@ type DefaultSubscriptionAssigner interface { AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) } +type signupGrantPlan struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting +} + // NewAuthService 创建认证服务实例 func NewAuthService( entClient *dbent.Client, @@ -187,21 +193,15 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, fmt.Errorf("hash password: %w", err) } - // 获取默认配置 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + grantPlan := s.resolveSignupGrantPlan(ctx, "email") // 创建用户 user := &User{ Email: email, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -214,7 +214,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, ErrServiceUnavailable } s.postAuthUserBootstrap(ctx, user, "email", true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { @@ -479,21 +479,16 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return "", nil, fmt.Errorf("hash password: %w", err) } - // 新用户默认值。 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -511,8 +506,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser - s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -596,20 +591,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, fmt.Errorf("hash password: %w", err) } - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -642,8 +633,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrServiceUnavailable } user = newUser - s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -659,8 +650,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser - s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { return nil, nil, ErrInvitationCodeInvalid @@ -694,22 +685,78 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil { + return + } + s.assignSubscriptions(ctx, userID, s.settingService.GetDefaultSubscriptions(ctx), "auto assigned by default user subscriptions setting") +} + +func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return } - items := s.settingService.GetDefaultSubscriptions(ctx) for _, item := range items { if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ UserID: userID, GroupID: item.GroupID, ValidityDays: item.ValidityDays, - Notes: "auto assigned by default user subscriptions setting", + Notes: notes, }); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) } } } +func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan { + plan := signupGrantPlan{} + if s != nil && s.cfg != nil { + plan.Balance = s.cfg.Default.UserBalance + plan.Concurrency = s.cfg.Default.UserConcurrency + } + if s == nil || s.settingService == nil { + return plan + } + + plan.Balance = s.settingService.GetDefaultBalance(ctx) + plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx) + plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx) + + defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err) + return plan + } + + providerDefaults, ok := authSourceSignupSettings(defaults, signupSource) + if !ok || !providerDefaults.GrantOnSignup { + return plan + } + + plan.Balance = providerDefaults.Balance + plan.Concurrency = providerDefaults.Concurrency + plan.Subscriptions = providerDefaults.Subscriptions + return plan +} + +func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) { + if defaults == nil { + return ProviderDefaultGrantSettings{}, false + } + + switch strings.ToLower(strings.TrimSpace(signupSource)) { + case "email": + return defaults.Email, true + case "linuxdo": + return defaults.LinuxDo, true + case "oidc": + return defaults.OIDC, true + case "wechat": + return defaults.WeChat, true + default: + return ProviderDefaultGrantSettings{}, false + } +} + func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { if user == nil || user.ID <= 0 { return diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 103bafe7..901b3db3 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error { } func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { - panic("unexpected GetMultiple call") + if s.err != nil { + return nil, s.err + } + result := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + result[key] = v + } + } + return result, nil } func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { @@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct { err error } +type refreshTokenCacheStub struct{} + func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { if input != nil { s.calls = append(s.calls, *input) @@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil } +func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) { + return nil, ErrRefreshTokenNotFound +} + +func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { if s.err != nil { return nil, s.err @@ -484,3 +535,109 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { require.Equal(t, int64(12), assigner.calls[1].GroupID) require.Equal(t, 7, assigner.calls[1].ValidityDays) } + +func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) { + repo := &userRepoStub{nextID: 52} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`, + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 12.5, user.Balance) + require.Equal(t, 7, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) +} + +func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 53} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`, + SettingKeyAuthSourceDefaultEmailBalance: "99", + SettingKeyAuthSourceDefaultEmailConcurrency: "88", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-global@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 3.5, user.Balance) + require.Equal(t, 2, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(31), assigner.calls[0].GroupID) + require.Equal(t, 5, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) { + repo := &userRepoStub{nextID: 61} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`, + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Equal(t, int64(61), user.ID) + require.Equal(t, 21.75, user.Balance) + require.Equal(t, 9, user.Concurrency) + require.Len(t, repo.created, 1) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(22), assigner.calls[0].GroupID) + require.Equal(t, 14, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) { + existing := &User{ + ID: 88, + Email: "linuxdo-123@linuxdo-connect.invalid", + Username: "existing-linuxdo", + Role: RoleUser, + Status: StatusActive, + Balance: 4, + Concurrency: 1, + TokenVersion: 2, + } + repo := &userRepoStub{user: existing} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.Equal(t, existing.ID, user.ID) + require.Equal(t, 4.0, user.Balance) + require.Equal(t, 1, user.Concurrency) + require.Empty(t, repo.created) + require.Empty(t, assigner.calls) +} From 0353c3870f938f5d57d4060232dfe9e3bc4a01ff Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 18:40:34 +0800 Subject: [PATCH 023/326] test: update user service stubs for identity summaries --- .../service/billing_cache_service_singleflight_test.go | 4 ++++ backend/internal/service/user_service_test.go | 9 ++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go index 4a8b8f03..2ebc2f04 100644 --- a/backend/internal/service/billing_cache_service_singleflight_test.go +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -86,6 +86,10 @@ func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, return &User{ID: id, Balance: s.balance}, nil } +func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + return nil, nil +} + func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { cache := &billingCacheMissStub{} userRepo := &balanceLoadUserRepoStub{ diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 7d63bb36..89f0362e 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -100,9 +100,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int return 0, nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } -func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } -func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + return nil, nil +} +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { return nil } From d47580a144f083965b04e4612e4fd9ceb4779e35 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 18:42:28 +0800 Subject: [PATCH 024/326] test: pin email signup defaults in register tests --- backend/internal/service/auth_service_register_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 901b3db3..e0dce982 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -373,7 +373,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { func TestAuthService_Register_Success(t *testing.T) { repo := &userRepoStub{nextID: 5} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEnabled: "true", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", }, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") @@ -520,8 +521,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { repo := &userRepoStub{nextID: 42} assigner := &defaultSubscriptionAssignerStub{} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", }, nil) service.defaultSubAssigner = assigner From 6a75bd77e3cbf11a12a624a474a979f684be5fb6 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 19:30:09 +0800 Subject: [PATCH 025/326] feat: add pending oauth email onboarding flow --- .../internal/handler/auth_linuxdo_oauth.go | 125 +++-- .../handler/auth_oauth_pending_flow.go | 426 +++++++++++++++++- .../handler/auth_oauth_pending_flow_test.go | 371 ++++++++++++++- backend/internal/handler/auth_oidc_oauth.go | 137 +++--- backend/internal/handler/auth_wechat_oauth.go | 33 ++ backend/internal/handler/dto/settings.go | 1 + backend/internal/handler/setting_handler.go | 1 + .../handler/setting_handler_public_test.go | 83 ++++ backend/internal/server/routes/auth.go | 48 ++ .../internal/service/auth_oauth_email_flow.go | 151 +++++++ backend/internal/service/setting_service.go | 2 + .../service/setting_service_public_test.go | 13 + backend/internal/service/settings_view.go | 1 + 13 files changed, 1273 insertions(+), 119 deletions(-) create mode 100644 backend/internal/handler/setting_handler_public_test.go create mode 100644 backend/internal/service/auth_oauth_email_flow.go diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 835e5fd8..c4ecb8fa 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -243,6 +243,18 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { if subject != "" { email = linuxDoSyntheticEmail(subject) } + identityKey := service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + } if intent == oauthIntentBindCurrentUser { targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName) if err != nil { @@ -250,23 +262,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: oauthIntentBindCurrentUser, - Identity: service.PendingAuthIdentityKey{ - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: subject, - }, - TargetUserID: &targetUserID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "suggested_display_name": displayName, - "suggested_avatar_url": avatarURL, - }, + Intent: oauthIntentBindCurrentUser, + Identity: identityKey, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "redirect": redirectTo, }, @@ -278,27 +280,60 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityKey, + TargetUserID: &user.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: subject, - }, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "suggested_display_name": displayName, - "suggested_avatar_url": avatarURL, - }, + Intent: "login", + Identity: identityKey, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "error": "invitation_required", "redirect": redirectTo, @@ -316,23 +351,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: subject, - }, - TargetUserID: &user.ID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "suggested_display_name": displayName, - "suggested_avatar_url": avatarURL, - }, + Intent: "login", + Identity: identityKey, + TargetUserID: &user.ID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "access_token": tokenPair.AccessToken, "refresh_token": tokenPair.RefreshToken, diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 2d6c3714..99b9b406 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -46,6 +46,36 @@ type oauthAdoptionDecisionRequest struct { AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } +type bindPendingOAuthLoginRequest struct { + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +type createPendingOAuthAccountRequest struct { + Email string `json:"email" binding:"required,email"` + VerifyCode string `json:"verify_code,omitempty"` + Password string `json:"password" binding:"required,min=6"` + InvitationCode string `json:"invitation_code,omitempty"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + +func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { if h == nil || h.authService == nil || h.authService.EntClient() == nil { return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") @@ -170,6 +200,36 @@ func readCompletionResponse(session map[string]any) (map[string]any, bool) { return result, true } +func clonePendingMap(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any { + payload, _ := readCompletionResponse(session.LocalFlowState) + merged := clonePendingMap(payload) + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := merged["redirect"]; !exists { + merged["redirect"] = session.RedirectTo + } + } + for key, value := range overrides { + if value == nil { + delete(merged, key) + continue + } + merged[key] = value + } + applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims) + return merged +} + func pendingSessionStringValue(values map[string]any, key string) string { if len(values) == 0 { return "" @@ -264,6 +324,89 @@ func (h *AuthHandler) entClient() *dbent.Client { return h.authService.EntClient() } +func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool { + if h == nil || h.settingSvc == nil { + return false + } + defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx) + if err != nil || defaults == nil { + return false + } + return defaults.ForceEmailOnThirdPartySignup +} + +func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + record, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + + userEntity, err := client.User.Get(ctx, record.UserID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) createOAuthEmailRequiredPendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, +) error { + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + "step": "email_required", + "force_email_on_signup": true, + "email_binding_required": true, + "existing_account_bindable": true, + }, + }) +} + +func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") } +func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") } +func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") } +func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") } + +func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "linuxdo") +} + +func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") } + +func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "wechat") +} + +func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "") +} + func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( c *gin.Context, sessionID int64, @@ -313,6 +456,60 @@ func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( return decision, nil } +func (h *AuthHandler) ensurePendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) (*dbent.IdentityAdoptionDecision, error) { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req) + if err != nil { + return nil, err + } + if decision != nil { + return decision, nil + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + }) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func updatePendingOAuthSessionProgress( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + intent string, + resolvedEmail string, + targetUserID *int64, + completionResponse map[string]any, +) (*dbent.PendingAuthSession, error) { + if client == nil || session == nil { + return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + + localFlowState := clonePendingMap(session.LocalFlowState) + localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse) + + update := client.PendingAuthSession.UpdateOneID(session.ID). + SetIntent(strings.TrimSpace(intent)). + SetResolvedEmail(strings.TrimSpace(resolvedEmail)). + SetLocalFlowState(localFlowState) + if targetUserID != nil && *targetUserID > 0 { + update = update.SetTargetUserID(*targetUserID) + } else { + update = update.ClearTargetUserID() + } + return update.Save(ctx) +} + func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) { if session == nil { return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") @@ -401,17 +598,18 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision return decision.AdoptDisplayName || decision.AdoptAvatar } -func applyPendingOAuthAdoption( +func applyPendingOAuthBinding( ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision, overrideUserID *int64, + forceBind bool, ) error { - if client == nil || session == nil || decision == nil { + if client == nil || session == nil { return nil } - if !shouldBindPendingOAuthIdentity(session, decision) { + if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) { return nil } @@ -427,11 +625,11 @@ func applyPendingOAuthAdoption( } adoptedDisplayName := "" - if decision.AdoptDisplayName { + if decision != nil && decision.AdoptDisplayName { adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name")) } adoptedAvatarURL := "" - if decision.AdoptAvatar { + if decision != nil && decision.AdoptAvatar { adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") } @@ -441,7 +639,7 @@ func applyPendingOAuthAdoption( } defer func() { _ = tx.Rollback() }() - if decision.AdoptDisplayName && adoptedDisplayName != "" { + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { if err := tx.Client().User.UpdateOneID(targetUserID). SetUsername(adoptedDisplayName). Exec(ctx); err != nil { @@ -458,10 +656,10 @@ func applyPendingOAuthAdoption( for key, value := range session.UpstreamIdentityClaims { metadata[key] = value } - if decision.AdoptDisplayName && adoptedDisplayName != "" { + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { metadata["display_name"] = adoptedDisplayName } - if decision.AdoptAvatar && adoptedAvatarURL != "" { + if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" { metadata["avatar_url"] = adoptedAvatarURL } @@ -473,7 +671,7 @@ func applyPendingOAuthAdoption( return err } - if decision.IdentityID == nil || *decision.IdentityID != identity.ID { + if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) { if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). SetIdentityID(identity.ID). Save(ctx); err != nil { @@ -484,6 +682,16 @@ func applyPendingOAuthAdoption( return tx.Commit() } +func applyPendingOAuthAdoption( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, +) error { + return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false) +} + func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { if len(payload) == 0 || len(upstream) == 0 { return @@ -507,6 +715,206 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream } } +func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + return svc, session, clearCookies, nil +} + +func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H { + payload := gin.H{ + "auth_result": "pending_session", + "provider": strings.TrimSpace(session.ProviderType), + "intent": strings.TrimSpace(session.Intent), + } + for key, value := range mergePendingCompletionResponse(session, nil) { + payload[key] = value + } + if email := strings.TrimSpace(session.ResolvedEmail); email != "" { + payload["email"] = email + } + return payload +} + +func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) { + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { + var req bindPendingOAuthLoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password) + if err != nil { + response.ErrorFrom(c, err) + return + } + if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID { + response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user")) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") + if err != nil { + response.InternalError(c, "Failed to generate token pair") + return + } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + +func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) { + var req createPendingOAuthAccountRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + + email := strings.TrimSpace(strings.ToLower(req.Email)) + existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context()) + if err != nil && !dbent.IsNotFound(err) { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")) + return + } + if existingUser != nil { + completionResponse := mergePendingCompletionResponse(session, map[string]any{ + "step": "bind_login_required", + "email": email, + }) + session, err = updatePendingOAuthSessionProgress( + c.Request.Context(), + client, + session, + "adopt_existing_user_by_email", + email, + &existingUser.ID, + completionResponse, + ) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)) + return + } + + if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil { + response.ErrorFrom(c, err) + return + } + + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } + + tokenPair, user, err := h.authService.RegisterOAuthEmailAccount( + c.Request.Context(), + email, + req.Password, + strings.TrimSpace(req.VerifyCode), + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + // ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload. // POST /api/v1/auth/oauth/pending/exchange func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 3afb4fb7..80338b8a 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -509,9 +509,305 @@ func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecis require.Nil(t, storedSession.ConsumedAt) } +func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810") + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-create-123"). + SetBrowserSessionKey("create-account-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Fresh OIDC User", + "suggested_avatar_url": "https://cdn.example/fresh.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx) + require.NoError(t, err) + require.Equal(t, service.StatusActive, createdUser.Status) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-create-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, createdUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-123"). + SetBrowserSessionKey("existing-email-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Existing OIDC User", + "suggested_avatar_url": "https://cdn.example/existing.png", + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "pending_session", payload["auth_result"]) + require.Equal(t, "adopt_existing_user_by_email", payload["intent"]) + require.Equal(t, "oidc", payload["provider"]) + require.Equal(t, "/dashboard", payload["redirect"]) + require.Equal(t, true, payload["adoption_required"]) + require.Equal(t, "Existing OIDC User", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) + require.Nil(t, storedSession.ConsumedAt) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-existing-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) +} + +func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, existingUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-invalid-password-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-invalid-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-invalid-password-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "INVALID_CREDENTIALS", payload["reason"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil) +} + +func newOAuthPendingFlowTestHandlerWithEmailVerification( + t *testing.T, + invitationEnabled bool, + email string, + code string, +) (*AuthHandler, *dbent.Client) { + t.Helper() + + cache := &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + email: { + Code: code, + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + } + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache) +} + +func newOAuthPendingFlowTestHandlerWithOptions( + t *testing.T, + invitationEnabled bool, + emailVerifyEnabled bool, + emailCache service.EmailCache, +) (*AuthHandler, *dbent.Client) { + t.Helper() + db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared") require.NoError(t, err) t.Cleanup(func() { _ = db.Close() }) @@ -538,9 +834,18 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth values: map[string]string{ service.SettingKeyRegistrationEnabled: "true", service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled), }, }, cfg) userRepo := &oauthPendingFlowUserRepo{client: client} + var emailService *service.EmailService + if emailCache != nil { + emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ + values: map[string]string{ + service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled), + }, + }, emailCache) + } authSvc := service.NewAuthService( client, userRepo, @@ -548,7 +853,7 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth &oauthPendingFlowRefreshTokenCacheStub{}, cfg, settingSvc, - nil, + emailService, nil, nil, nil, @@ -622,6 +927,70 @@ func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error type oauthPendingFlowRefreshTokenCacheStub struct{} +type oauthPendingFlowEmailCacheStub struct { + verificationCodes map[string]*service.VerificationCodeData +} + +func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) { + if s == nil || s.verificationCodes == nil { + return nil, nil + } + return s.verificationCodes[email], nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error { + if s.verificationCodes == nil { + s.verificationCodes = map[string]*service.VerificationCodeData{} + } + s.verificationCodes[email] = data + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error { + delete(s.verificationCodes, email) + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { return nil } diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 0f79759e..909d6379 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -342,6 +342,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { idClaims.Name, oidcFallbackUsername(subject), ) + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: issuer, + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, + } if intent == oauthIntentBindCurrentUser { targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName) if err != nil { @@ -349,26 +364,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: oauthIntentBindCurrentUser, - Identity: service.PendingAuthIdentityKey{ - ProviderType: "oidc", - ProviderKey: issuer, - ProviderSubject: subject, - }, - TargetUserID: &targetUserID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), - "suggested_avatar_url": userInfoClaims.AvatarURL, - }, + Intent: oauthIntentBindCurrentUser, + Identity: identityRef, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "redirect": redirectTo, }, @@ -380,30 +382,60 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityRef, + TargetUserID: &user.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "oidc", - ProviderKey: issuer, - ProviderSubject: subject, - }, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), - "suggested_avatar_url": userInfoClaims.AvatarURL, - }, + Intent: "login", + Identity: identityRef, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "error": "invitation_required", "redirect": redirectTo, @@ -420,26 +452,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "oidc", - ProviderKey: issuer, - ProviderSubject: subject, - }, - TargetUserID: &user.ID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), - "suggested_avatar_url": userInfoClaims.AvatarURL, - }, + Intent: "login", + Identity: identityRef, + TargetUserID: &user.ID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "access_token": tokenPair.AccessToken, "refresh_token": tokenPair.RefreshToken, diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 45ac6cad..6d37c799 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -214,6 +214,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { "suggested_display_name": strings.TrimSpace(userInfo.Nickname), "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL), } + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + } normalizedIntent := normalizeWeChatOAuthIntent(intent) if normalizedIntent == wechatOAuthIntentBind { @@ -232,6 +237,34 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { return } + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err, nil); err != nil { diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index f44b3e3b..637e317b 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -167,6 +167,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool `json:"registration_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"` + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` PromoCodeEnabled bool `json:"promo_code_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"` diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index c7bc3e2a..9925f066 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { response.Success(c, dto.PublicSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go new file mode 100644 index 00000000..114c7245 --- /dev/null +++ b/backend/internal/handler/setting_handler_public_test.go @@ -0,0 +1,83 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerPublicRepoStub struct { + values map[string]string +} + +func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &settingHandlerPublicRepoStub{ + values: map[string]string{ + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil) + + h.GetPublicSettings(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.ForceEmailOnThirdPartySignup) +} diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 7a34834d..1f28e9c3 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -72,18 +72,54 @@ func RegisterAuthRoutes( }), h.Auth.ExchangePendingOAuthCompletion, ) + auth.POST("/oauth/pending/create-account", + rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreatePendingOAuthAccount, + ) + auth.POST("/oauth/pending/bind-login", + rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindPendingOAuthLogin, + ) auth.POST("/oauth/linuxdo/complete-registration", rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteLinuxDoOAuthRegistration, ) + auth.POST("/oauth/linuxdo/bind-login", + rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindLinuxDoOAuthLogin, + ) + auth.POST("/oauth/linuxdo/create-account", + rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateLinuxDoOAuthAccount, + ) auth.POST("/oauth/wechat/complete-registration", rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteWeChatOAuthRegistration, ) + auth.POST("/oauth/wechat/bind-login", + rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindWeChatOAuthLogin, + ) + auth.POST("/oauth/wechat/create-account", + rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateWeChatOAuthAccount, + ) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.POST("/oauth/oidc/complete-registration", @@ -92,6 +128,18 @@ func RegisterAuthRoutes( }), h.Auth.CompleteOIDCOAuthRegistration, ) + auth.POST("/oauth/oidc/bind-login", + rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindOIDCOAuthLogin, + ) + auth.POST("/oauth/oidc/create-account", + rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateOIDCOAuthAccount, + ) } // 公开设置(无需认证) diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go new file mode 100644 index 00000000..ca3403d4 --- /dev/null +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -0,0 +1,151 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strings" +) + +// VerifyOAuthEmailCode verifies the locally entered email verification code for +// third-party signup and binding flows. This is intentionally independent from +// the global registration email verification toggle. +func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error { + email = strings.TrimSpace(strings.ToLower(email)) + verifyCode = strings.TrimSpace(verifyCode) + + if email == "" { + return ErrEmailVerifyRequired + } + if verifyCode == "" { + return ErrEmailVerifyRequired + } + if s == nil || s.emailService == nil { + return ErrServiceUnavailable + } + return s.emailService.VerifyCode(ctx, email, verifyCode) +} + +// RegisterOAuthEmailAccount creates a local account from a third-party first +// login after the user has verified a local email address. +func (s *AuthService) RegisterOAuthEmailAccount( + ctx context.Context, + email string, + password string, + verifyCode string, + invitationCode string, + signupSource string, +) (*TokenPair, *User, error) { + if s == nil { + return nil, nil, ErrServiceUnavailable + } + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, nil, ErrRegDisabled + } + + email = strings.TrimSpace(strings.ToLower(email)) + if isReservedEmail(email) { + return nil, nil, ErrEmailReserved + } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, nil, err + } + if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil { + return nil, nil, err + } + + var invitationRedeemCode *RedeemCode + if s.settingService.IsInvitationCodeEnabled(ctx) { + if invitationCode == "" { + return nil, nil, ErrInvitationCodeRequired + } + redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + if err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + return nil, nil, ErrInvitationCodeInvalid + } + invitationRedeemCode = redeemCode + } + + existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) + if err != nil { + return nil, nil, ErrServiceUnavailable + } + if existsEmail { + return nil, nil, ErrEmailExists + } + + hashedPassword, err := s.HashPassword(password) + if err != nil { + return nil, nil, fmt.Errorf("hash password: %w", err) + } + + signupSource = strings.TrimSpace(strings.ToLower(signupSource)) + if signupSource == "" { + signupSource = "email" + } + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + + user := &User{ + Email: email, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, user); err != nil { + if errors.Is(err, ErrEmailExists) { + return nil, nil, ErrEmailExists + } + return nil, nil, ErrServiceUnavailable + } + + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + + if invitationRedeemCode != nil { + if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + } + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + +// ValidatePasswordCredentials checks the local password without completing the +// login flow. This is used by pending third-party account adoption flows before +// the external identity has been bound. +func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) { + if s == nil { + return nil, ErrServiceUnavailable + } + + user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email))) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, ErrInvalidCredentials + } + return nil, ErrServiceUnavailable + } + if !user.IsActive() { + return nil, ErrUserNotActive + } + if !s.CheckPassword(password, user.PasswordHash) { + return nil, ErrInvalidCredentials + } + return user, nil +} + +// RecordSuccessfulLogin updates last-login activity after a non-standard login +// flow finishes with a real session. +func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) { + s.touchUserLogin(ctx, userID) +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index de555478..a2644fcd 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -217,6 +217,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings keys := []string{ SettingKeyRegistrationEnabled, SettingKeyEmailVerifyEnabled, + SettingKeyForceEmailOnThirdPartySignup, SettingKeyRegistrationEmailSuffixWhitelist, SettingKeyPromoCodeEnabled, SettingKeyPasswordResetEnabled, @@ -294,6 +295,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings return &PublicSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: emailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist, PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PasswordResetEnabled: passwordResetEnabled, diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 5cf1e860..bb97c2aa 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -77,3 +77,16 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) require.Equal(t, 50, settings.TableDefaultPageSize) require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions) } + +func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + repo := &settingPublicRepoStub{ + values: map[string]string{ + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.ForceEmailOnThirdPartySignup) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index e991ebef..72db4e31 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -128,6 +128,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool + ForceEmailOnThirdPartySignup bool RegistrationEmailSuffixWhitelist []string PromoCodeEnabled bool PasswordResetEnabled bool From 6ea3f42e2f825f165c98a63fa6b5f472f6853b33 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 19:30:19 +0800 Subject: [PATCH 026/326] feat: add oauth callback email binding ui --- .../api/__tests__/auth-oauth-adoption.spec.ts | 91 +++++++ frontend/src/api/auth.ts | 102 +++++-- frontend/src/stores/app.ts | 1 + frontend/src/types/index.ts | 1 + .../src/views/auth/LinuxDoCallbackView.vue | 257 +++++++++++++++++- frontend/src/views/auth/OidcCallbackView.vue | 128 ++++++++- .../src/views/auth/WechatCallbackView.vue | 108 ++++++++ .../__tests__/LinuxDoCallbackView.spec.ts | 105 +++++++ .../auth/__tests__/OidcCallbackView.spec.ts | 71 +++++ .../auth/__tests__/WechatCallbackView.spec.ts | 88 ++++++ 10 files changed, 916 insertions(+), 36 deletions(-) diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts index 9c0b4d55..f95332fb 100644 --- a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -30,6 +30,20 @@ describe('oauth adoption auth api', () => { }) }) + it('posts bind-login decisions when finalizing pending oauth bind flow', async () => { + const { completePendingOAuthBindLogin } = await import('@/api/auth') + + await completePendingOAuthBindLogin({ + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', { + adopt_display_name: true, + adopt_avatar: false + }) + }) + it('posts linuxdo invitation completion with adoption decisions', async () => { const { completeLinuxDoOAuthRegistration } = await import('@/api/auth') @@ -45,6 +59,21 @@ describe('oauth adoption auth api', () => { }) }) + it('posts linuxdo create-account completion with adoption decisions', async () => { + const { createPendingLinuxDoOAuthAccount } = await import('@/api/auth') + + await createPendingLinuxDoOAuthAccount('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: true + }) + }) + it('posts oidc invitation completion with adoption decisions', async () => { const { completeOIDCOAuthRegistration } = await import('@/api/auth') @@ -60,6 +89,21 @@ describe('oauth adoption auth api', () => { }) }) + it('posts oidc create-account completion with adoption decisions', async () => { + const { createPendingOIDCOAuthAccount } = await import('@/api/auth') + + await createPendingOIDCOAuthAccount('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: false + }) + }) + it('posts wechat invitation completion with adoption decisions', async () => { const { completeWeChatOAuthRegistration } = await import('@/api/auth') @@ -75,6 +119,21 @@ describe('oauth adoption auth api', () => { }) }) + it('posts wechat create-account completion with adoption decisions', async () => { + const { createPendingWeChatOAuthAccount } = await import('@/api/auth') + + await createPendingWeChatOAuthAccount('invite-code', { + adoptDisplayName: false, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: false + }) + }) + it('classifies oauth completion results as login or bind', async () => { const { getOAuthCompletionKind } = await import('@/api/auth') @@ -82,6 +141,38 @@ describe('oauth adoption auth api', () => { expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind') }) + it('provides bind-login utility helpers for invitation and suggested profile states', async () => { + const { + getPendingOAuthBindLoginKind, + hasPendingOAuthSuggestedProfile, + isPendingOAuthCreateAccountRequired + } = await import('@/api/auth') + + expect(getPendingOAuthBindLoginKind({ access_token: 'access-token' })).toBe('login') + expect(getPendingOAuthBindLoginKind({ redirect: '/profile' })).toBe('bind') + expect( + isPendingOAuthCreateAccountRequired({ + error: 'invitation_required' + }) + ).toBe(true) + expect( + isPendingOAuthCreateAccountRequired({ + error: 'other' + }) + ).toBe(false) + expect( + hasPendingOAuthSuggestedProfile({ + suggested_display_name: 'OAuth Nick' + }) + ).toBe(true) + expect( + hasPendingOAuthSuggestedProfile({ + suggested_avatar_url: 'https://cdn.example/avatar.png' + }) + ).toBe(true) + expect(hasPendingOAuthSuggestedProfile({})).toBe(false) + }) + it('prepares an oauth bind access token cookie before redirect binding', async () => { localStorage.setItem('auth_token', 'access-token-value') const setCookie = vi.fn() diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index c11bd90b..98b20154 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -193,7 +193,7 @@ export interface OAuthTokenResponse { token_type?: string } -export interface PendingOAuthExchangeResponse extends Partial { +export interface PendingOAuthBindLoginResponse extends Partial { redirect?: string error?: string adoption_required?: boolean @@ -201,6 +201,10 @@ export interface PendingOAuthExchangeResponse extends Partial +): boolean { + return completion.error === 'invitation_required' +} + +export function hasPendingOAuthSuggestedProfile( + completion: Pick< + PendingOAuthBindLoginResponse, + 'suggested_display_name' | 'suggested_avatar_url' + > +): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + export function persistOAuthTokenContext(tokens: Partial): void { if (tokens.refresh_token) { setRefreshToken(tokens.refresh_token) @@ -431,11 +456,7 @@ export async function completeLinuxDoOAuthRegistration( invitationCode: string, decision?: OAuthAdoptionDecision ): Promise { - const { data } = await apiClient.post('/auth/oauth/linuxdo/complete-registration', { - invitation_code: invitationCode, - ...serializeOAuthAdoptionDecision(decision) - }) - return data + return createPendingLinuxDoOAuthAccount(invitationCode, decision) } /** @@ -447,32 +468,66 @@ export async function completeOIDCOAuthRegistration( invitationCode: string, decision?: OAuthAdoptionDecision ): Promise { - const { data } = await apiClient.post('/auth/oauth/oidc/complete-registration', { - invitation_code: invitationCode, - ...serializeOAuthAdoptionDecision(decision) - }) - return data + return createPendingOIDCOAuthAccount(invitationCode, decision) } export async function completeWeChatOAuthRegistration( invitationCode: string, decision?: OAuthAdoptionDecision ): Promise { - const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', { - invitation_code: invitationCode, - ...serializeOAuthAdoptionDecision(decision) - }) + return createPendingWeChatOAuthAccount(invitationCode, decision) +} + +async function createPendingOAuthAccount( + provider: 'linuxdo' | 'oidc' | 'wechat', + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + `/auth/oauth/${provider}/complete-registration`, + { + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) + } + ) + return data +} + +export async function createPendingLinuxDoOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('linuxdo', invitationCode, decision) +} + +export async function createPendingOIDCOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('oidc', invitationCode, decision) +} + +export async function createPendingWeChatOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('wechat', invitationCode, decision) +} + +export async function completePendingOAuthBindLogin( + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + '/auth/oauth/pending/exchange', + serializeOAuthAdoptionDecision(decision) + ) return data } export async function exchangePendingOAuthCompletion( decision?: OAuthAdoptionDecision ): Promise { - const { data } = await apiClient.post( - '/auth/oauth/pending/exchange', - serializeOAuthAdoptionDecision(decision) - ) - return data + return completePendingOAuthBindLogin(decision) } export const authAPI = { @@ -498,6 +553,13 @@ export const authAPI = { resetPassword, refreshToken, revokeAllSessions, + getPendingOAuthBindLoginKind, + isPendingOAuthCreateAccountRequired, + hasPendingOAuthSuggestedProfile, + completePendingOAuthBindLogin, + createPendingLinuxDoOAuthAccount, + createPendingOIDCOAuthAccount, + createPendingWeChatOAuthAccount, exchangePendingOAuthCompletion, completeLinuxDoOAuthRegistration, completeOIDCOAuthRegistration, diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 1b1af87b..a8e03a51 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -316,6 +316,7 @@ export const useAppStore = defineStore('app', () => { return { registration_enabled: false, email_verify_enabled: false, + force_email_on_third_party_signup: false, registration_email_suffix_whitelist: [], promo_code_enabled: true, password_reset_enabled: false, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index a19d6c26..a4b2277b 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -142,6 +142,7 @@ export interface CustomEndpoint { export interface PublicSettings { registration_enabled: boolean email_verify_enabled: boolean + force_email_on_third_party_signup: boolean registration_email_suffix_whitelist: string[] promo_code_enabled: boolean password_reset_enabled: boolean diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index 6dc8f242..00b73868 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -11,7 +11,10 @@
-
+
+ + + +
@@ -127,11 +214,12 @@ diff --git a/frontend/src/components/admin/monitor/MonitorFiltersBar.vue b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue new file mode 100644 index 00000000..ebb06a68 --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue @@ -0,0 +1,95 @@ + + + diff --git a/frontend/src/components/admin/monitor/MonitorFormDialog.vue b/frontend/src/components/admin/monitor/MonitorFormDialog.vue new file mode 100644 index 00000000..920c3f79 --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorFormDialog.vue @@ -0,0 +1,297 @@ + + + diff --git a/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue new file mode 100644 index 00000000..eefe4073 --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue @@ -0,0 +1,64 @@ + + + diff --git a/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue b/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue new file mode 100644 index 00000000..eccec828 --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue @@ -0,0 +1,71 @@ + + + diff --git a/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue b/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue new file mode 100644 index 00000000..02fa6e8d --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue @@ -0,0 +1,56 @@ + + + diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 92dcc519..23d0f4e9 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -38,7 +38,7 @@ 'sidebar-link-collapsed': sidebarCollapsed }" :title="sidebarCollapsed ? item.label : undefined" - @click="sidebarCollapsed ? undefined : toggleGroup(item)" + @click="handleGroupClick(item)" > import { computed, h, onMounted, ref, watch } from 'vue' -import { useRoute } from 'vue-router' +import { useRoute, useRouter } from 'vue-router' import { useI18n } from 'vue-i18n' import { useAdminSettingsStore, useAppStore, useAuthStore, useOnboardingStore } from '@/stores' import VersionBadge from '@/components/common/VersionBadge.vue' @@ -194,11 +194,17 @@ interface NavItem { iconSvg?: string hideInSimpleMode?: boolean children?: NavItem[] + /** + * When true, the parent item only toggles the expand/collapse state and + * does NOT navigate to its `path`. The `path` is purely a stable key. + */ + expandOnly?: boolean } const { t } = useI18n() const route = useRoute() +const router = useRouter() const appStore = useAppStore() const authStore = useAuthStore() const onboardingStore = useOnboardingStore() @@ -549,6 +555,41 @@ const ChevronDoubleRightIcon = { ) } +const SignalIcon = { + render: () => + h( + 'svg', + { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, + [ + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M9.348 14.651a3.75 3.75 0 010-5.303m5.304 0a3.75 3.75 0 010 5.303m-7.425 2.122a6.75 6.75 0 010-9.546m9.546 0a6.75 6.75 0 010 9.546M5.106 18.894c-3.808-3.807-3.808-9.98 0-13.788m13.788 0c3.808 3.807 3.808 9.98 0 13.788M12 12h.008v.008H12V12zm.375 0a.375.375 0 11-.75 0 .375.375 0 01.75 0z' + }) + ] + ) +} + +const PriceTagIcon = { + render: () => + h( + 'svg', + { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, + [ + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M9.568 3H5.25A2.25 2.25 0 003 5.25v4.318c0 .597.237 1.17.659 1.591l9.581 9.581c.699.699 1.78.872 2.607.33a18.095 18.095 0 005.223-5.223c.542-.827.369-1.908-.33-2.607L11.16 3.66A2.25 2.25 0 009.568 3z' + }), + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M6 6h.008v.008H6V6z' + }) + ] + ) +} + const ChevronDownIcon = { render: () => h( @@ -570,6 +611,7 @@ const userNavItems = computed((): NavItem[] => { { path: '/dashboard', label: t('nav.dashboard'), icon: DashboardIcon }, { path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }, { path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true }, + { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon }, { path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, ...(appStore.cachedPublicSettings?.payment_enabled ? [ @@ -608,6 +650,7 @@ const personalNavItems = computed((): NavItem[] => { const items: NavItem[] = [ { path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }, { path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true }, + { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon }, { path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, ...(appStore.cachedPublicSettings?.payment_enabled ? [ @@ -664,7 +707,17 @@ const adminNavItems = computed((): NavItem[] => { : []), { path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true }, { path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true }, - { path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true }, + { + path: '/admin/channels', + label: t('nav.channelManagement'), + icon: ChannelIcon, + hideInSimpleMode: true, + expandOnly: true, + children: [ + { path: '/admin/channels/pricing', label: t('nav.channelPricing'), icon: PriceTagIcon }, + { path: '/admin/channels/monitor', label: t('nav.channelMonitor'), icon: SignalIcon }, + ], + }, { path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, { path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon }, { path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon }, @@ -678,6 +731,7 @@ const adminNavItems = computed((): NavItem[] => { label: t('nav.orderManagement'), icon: OrderIcon, hideInSimpleMode: true, + expandOnly: true, children: [ { path: '/admin/orders/dashboard', label: t('nav.paymentDashboard'), icon: ChartIcon }, { path: '/admin/orders', label: t('nav.orderManagement'), icon: OrderIcon }, @@ -764,6 +818,28 @@ function toggleGroup(item: NavItem) { } } +/** + * Click handler for collapsible parent items. + * - When sidebar is collapsed: do nothing (children are not visible). + * - When `expandOnly` is true: only toggle expand state. + * - Otherwise (default, e.g. /admin/orders): navigate to the parent path + * (router-link semantics) and ensure the group is expanded. + */ +function handleGroupClick(item: NavItem) { + if (sidebarCollapsed.value) return + if (item.expandOnly) { + toggleGroup(item) + return + } + // Push to path and ensure expanded + if (route.path !== item.path) { + router.push(item.path) + } + if (!expandedGroups.value.has(item.path)) { + expandedGroups.value.add(item.path) + } +} + // Initialize theme const savedTheme = localStorage.getItem('theme') if ( diff --git a/frontend/src/components/user/MonitorDetailDialog.vue b/frontend/src/components/user/MonitorDetailDialog.vue new file mode 100644 index 00000000..564f461b --- /dev/null +++ b/frontend/src/components/user/MonitorDetailDialog.vue @@ -0,0 +1,114 @@ + + + diff --git a/frontend/src/components/user/MonitorPrimaryModelCell.vue b/frontend/src/components/user/MonitorPrimaryModelCell.vue new file mode 100644 index 00000000..32620b2a --- /dev/null +++ b/frontend/src/components/user/MonitorPrimaryModelCell.vue @@ -0,0 +1,71 @@ + + + diff --git a/frontend/src/composables/useChannelMonitorFormat.ts b/frontend/src/composables/useChannelMonitorFormat.ts new file mode 100644 index 00000000..fbb310fa --- /dev/null +++ b/frontend/src/composables/useChannelMonitorFormat.ts @@ -0,0 +1,97 @@ +/** + * Shared formatting helpers for channel monitor views (admin + user). + * + * Centralises: + * - status / provider label + badge class lookups + * - latency / availability / percent number formatting + * + * i18n keys live under `monitorCommon.*` so admin and user views share the + * same translation source. + */ + +import { useI18n } from 'vue-i18n' +import type { MonitorStatus, Provider } from '@/api/admin/channelMonitor' +import { + PROVIDER_OPENAI, + PROVIDER_ANTHROPIC, + PROVIDER_GEMINI, + STATUS_OPERATIONAL, + STATUS_DEGRADED, + STATUS_FAILED, + STATUS_ERROR, +} from '@/constants/channelMonitor' + +const NEUTRAL_BADGE = 'bg-gray-100 text-gray-800 dark:bg-dark-700 dark:text-gray-300' + +export interface AvailabilityRow { + primary_status: MonitorStatus | '' + availability_7d: number | null | undefined +} + +export function useChannelMonitorFormat() { + const { t } = useI18n() + + function statusLabel(s: MonitorStatus | ''): string { + if (!s) return t('monitorCommon.status.unknown') + return t(`monitorCommon.status.${s}`) + } + + function statusBadgeClass(s: MonitorStatus | ''): string { + switch (s) { + case STATUS_OPERATIONAL: + return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300' + case STATUS_DEGRADED: + return 'bg-yellow-100 text-yellow-800 dark:bg-yellow-900/30 dark:text-yellow-300' + case STATUS_FAILED: + return 'bg-red-100 text-red-800 dark:bg-red-900/30 dark:text-red-300' + case STATUS_ERROR: + default: + return NEUTRAL_BADGE + } + } + + function providerLabel(p: Provider | string): string { + if (p === PROVIDER_OPENAI || p === PROVIDER_ANTHROPIC || p === PROVIDER_GEMINI) { + return t(`monitorCommon.providers.${p}`) + } + return p || '-' + } + + function providerBadgeClass(p: Provider | string): string { + switch (p) { + case PROVIDER_OPENAI: + return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300' + case PROVIDER_ANTHROPIC: + return 'bg-orange-100 text-orange-800 dark:bg-orange-900/30 dark:text-orange-300' + case PROVIDER_GEMINI: + return 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-300' + default: + return NEUTRAL_BADGE + } + } + + function formatLatency(ms: number | null | undefined): string { + if (ms == null) return t('monitorCommon.latencyEmpty') + return String(Math.round(ms)) + } + + function formatPercent(v: number | null | undefined): string { + if (v == null || Number.isNaN(v)) return '-' + return `${v.toFixed(2)}%` + } + + function formatAvailability(row: AvailabilityRow): string { + if (!row.primary_status) return '-' + return formatPercent(row.availability_7d) + } + + return { + statusLabel, + statusBadgeClass, + providerLabel, + providerBadgeClass, + formatLatency, + formatPercent, + formatAvailability, + } +} diff --git a/frontend/src/constants/channelMonitor.ts b/frontend/src/constants/channelMonitor.ts new file mode 100644 index 00000000..7523a878 --- /dev/null +++ b/frontend/src/constants/channelMonitor.ts @@ -0,0 +1,35 @@ +/** + * Channel monitor shared constants. + * + * Single source of truth for provider/status string values used by both the + * admin (`views/admin/ChannelMonitorView.vue`) and user-facing + * (`views/user/ChannelStatusView.vue`) screens, plus the shared composable + * `useChannelMonitorFormat`. + */ + +import type { Provider, MonitorStatus } from '@/api/admin/channelMonitor' + +export const PROVIDER_OPENAI: Provider = 'openai' +export const PROVIDER_ANTHROPIC: Provider = 'anthropic' +export const PROVIDER_GEMINI: Provider = 'gemini' + +export const PROVIDERS: readonly Provider[] = [ + PROVIDER_OPENAI, + PROVIDER_ANTHROPIC, + PROVIDER_GEMINI, +] + +export const STATUS_OPERATIONAL: MonitorStatus = 'operational' +export const STATUS_DEGRADED: MonitorStatus = 'degraded' +export const STATUS_FAILED: MonitorStatus = 'failed' +export const STATUS_ERROR: MonitorStatus = 'error' + +export const MONITOR_STATUSES: readonly MonitorStatus[] = [ + STATUS_OPERATIONAL, + STATUS_DEGRADED, + STATUS_FAILED, + STATUS_ERROR, +] + +/** Default polling interval (seconds) for new monitors. */ +export const DEFAULT_INTERVAL_SECONDS = 60 diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 1b7ffa81..32fbce19 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -245,6 +245,7 @@ export default { // Common common: { loading: 'Loading...', + submitting: 'Submitting...', justNow: 'just now', save: 'Save', saved: 'Saved successfully', @@ -363,7 +364,11 @@ export default { orderManagement: 'Orders', paymentDashboard: 'Payment Dashboard', paymentConfig: 'Payment Config', - paymentPlans: 'Plans' + paymentPlans: 'Plans', + channelManagement: 'Channels', + channelPricing: 'Channel Pricing', + channelMonitor: 'Channel Monitor', + channelStatus: 'Channel Status', }, // Auth @@ -846,6 +851,58 @@ export default { userAgent: 'User-Agent' }, + // Shared keys for channel monitor (admin + user views) + monitorCommon: { + status: { + operational: 'Operational', + degraded: 'Degraded', + failed: 'Failed', + error: 'Error', + unknown: '-' + }, + providers: { + openai: 'OpenAI', + anthropic: 'Anthropic', + gemini: 'Gemini' + }, + extraModelsHeader: 'Extra Models', + extraModelsEmpty: 'No extra models', + latencyEmpty: '-' + }, + + // Channel Status (user-facing read-only view) + channelStatus: { + title: 'Channel Status', + description: 'Inspect channel availability, latency and recent status', + searchPlaceholder: 'Search channels...', + allProviders: 'All Providers', + loadError: 'Failed to load channel status', + detailLoadError: 'Failed to load channel detail', + detailTitle: 'Channel Detail', + closeDetail: 'Close', + columns: { + name: 'Name', + provider: 'Provider', + groupName: 'Group', + primaryModel: 'Primary Model', + availability7d: '7d Availability', + latency: 'Latency (ms)' + }, + detailColumns: { + model: 'Model', + latestStatus: 'Latest Status', + latestLatency: 'Latest Latency (ms)', + availability7d: '7d Availability', + availability15d: '15d Availability', + availability30d: '30d Availability', + avgLatency7d: '7d Avg Latency (ms)' + }, + empty: { + title: 'No channels available', + description: 'No monitored channels have been configured yet.' + } + }, + // Redeem redeem: { title: 'Redeem Code', @@ -2014,6 +2071,69 @@ export default { } }, + // Channel Monitor + channelMonitor: { + title: 'Channel Monitor', + description: 'Monitor channel availability, latency and status', + searchPlaceholder: 'Search monitor name...', + allProviders: 'All Providers', + allStatus: 'All Status', + enabledFilter: 'Enabled', + onlyEnabled: 'Enabled only', + onlyDisabled: 'Disabled only', + createButton: 'Create Monitor', + createTitle: 'Create Channel Monitor', + editTitle: 'Edit Channel Monitor', + runNow: 'Run Now', + runSuccess: 'Check completed', + runFailed: 'Check failed', + apiKeyDecryptFailed: 'API Key decryption failed. Please re-edit this monitor with a fresh key.', + createSuccess: 'Monitor created', + updateSuccess: 'Monitor updated', + deleteSuccess: 'Monitor deleted', + loadError: 'Failed to load monitors', + deleteConfirm: 'Are you sure you want to delete monitor "{name}"? This action cannot be undone.', + nameRequired: 'Please enter a monitor name', + primaryModelRequired: 'Please enter a primary model', + columns: { + name: 'Name', + provider: 'Provider', + primaryModel: 'Primary Model', + availability7d: '7d Availability', + latency: 'Latency (ms)', + enabled: 'Enabled', + actions: 'Actions' + }, + form: { + name: 'Name', + namePlaceholder: 'Enter monitor name', + provider: 'Provider', + endpoint: 'Endpoint', + endpointPlaceholder: 'https://api.example.com', + useCurrentDomain: 'Use current service', + apiKey: 'API Key', + apiKeyPlaceholder: 'Enter API Key', + apiKeyEditPlaceholder: 'Leave blank to keep current key', + useMyKey: 'Use my key', + selectKeyTitle: 'Select my API Key', + selectKeyHint: 'Only your active, non-expired keys are listed.', + noActiveKey: 'No active API keys available', + primaryModel: 'Primary Model', + primaryModelPlaceholder: 'gpt-4o-mini', + extraModels: 'Extra Models', + extraModelsPlaceholder: 'Press Enter to add extra model', + groupName: 'Group Name', + groupNamePlaceholder: 'Optional, used to group rows in user view', + intervalSeconds: 'Interval (seconds)', + intervalSecondsHint: 'Range: 15 - 3600 seconds', + enabled: 'Enable monitor', + kindRequired: 'Please select a provider' + }, + runResultTitle: 'Check Result', + noMonitorsYet: 'No monitors yet', + createFirstMonitor: 'Create your first monitor to track channel availability' + }, + // Subscriptions subscriptions: { title: 'Subscription Management', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index beb6841f..dd3af363 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -245,6 +245,7 @@ export default { // Common common: { loading: '加载中...', + submitting: '提交中...', justNow: '刚刚', save: '保存', saved: '保存成功', @@ -363,7 +364,11 @@ export default { orderManagement: '订单管理', paymentDashboard: '支付概览', paymentConfig: '支付配置', - paymentPlans: '订阅套餐' + paymentPlans: '订阅套餐', + channelManagement: '渠道管理', + channelPricing: '渠道定价', + channelMonitor: '渠道监控', + channelStatus: '渠道状态', }, // Auth @@ -850,6 +855,58 @@ export default { userAgent: 'User-Agent' }, + // Shared keys for channel monitor (admin + user views) + monitorCommon: { + status: { + operational: '正常', + degraded: '降级', + failed: '失败', + error: '错误', + unknown: '-' + }, + providers: { + openai: 'OpenAI', + anthropic: 'Anthropic', + gemini: 'Gemini' + }, + extraModelsHeader: '附加模型', + extraModelsEmpty: '无附加模型', + latencyEmpty: '-' + }, + + // Channel Status (user-facing read-only view) + channelStatus: { + title: '渠道状态', + description: '查看渠道可用性、延迟和近期状态', + searchPlaceholder: '搜索渠道...', + allProviders: '全部供应商', + loadError: '加载渠道状态失败', + detailLoadError: '加载渠道详情失败', + detailTitle: '渠道详情', + closeDetail: '关闭', + columns: { + name: '名称', + provider: '供应商', + groupName: '分组', + primaryModel: '主模型', + availability7d: '7 天可用率', + latency: '延迟 (ms)' + }, + detailColumns: { + model: '模型', + latestStatus: '最新状态', + latestLatency: '最新延迟 (ms)', + availability7d: '7 天可用率', + availability15d: '15 天可用率', + availability30d: '30 天可用率', + avgLatency7d: '7 天平均延迟 (ms)' + }, + empty: { + title: '暂无可显示的渠道', + description: '管理员尚未配置可监控的渠道。' + } + }, + // Redeem redeem: { title: '兑换码', @@ -2093,6 +2150,69 @@ export default { } }, + // Channel Monitor + channelMonitor: { + title: '渠道监控', + description: '监测各渠道的可用性、延迟和状态', + searchPlaceholder: '搜索监控名称...', + allProviders: '全部供应商', + allStatus: '全部状态', + enabledFilter: '启用状态', + onlyEnabled: '仅启用', + onlyDisabled: '仅禁用', + createButton: '新增监控', + createTitle: '新增渠道监控', + editTitle: '编辑渠道监控', + runNow: '立即检测', + runSuccess: '检测完成', + runFailed: '检测失败', + apiKeyDecryptFailed: 'API Key 解密失败,请重新编辑该监控并填入新的 Key', + createSuccess: '监控创建成功', + updateSuccess: '监控更新成功', + deleteSuccess: '监控删除成功', + loadError: '加载监控列表失败', + deleteConfirm: '确定要删除监控「{name}」吗?此操作不可撤销。', + nameRequired: '请输入监控名称', + primaryModelRequired: '请输入主模型', + columns: { + name: '名称', + provider: '供应商', + primaryModel: '主模型', + availability7d: '7 天可用率', + latency: '延迟 (ms)', + enabled: '启用', + actions: '操作' + }, + form: { + name: '名称', + namePlaceholder: '输入监控名称', + provider: '供应商', + endpoint: '上游地址', + endpointPlaceholder: 'https://api.example.com', + useCurrentDomain: '使用当前服务', + apiKey: 'API Key', + apiKeyPlaceholder: '请输入 API Key', + apiKeyEditPlaceholder: '留空表示不修改', + useMyKey: '使用我的 Key', + selectKeyTitle: '选择我的 API Key', + selectKeyHint: '仅显示当前账号下处于「启用」状态且未过期的 Key。', + noActiveKey: '没有可用的启用状态 Key', + primaryModel: '主模型', + primaryModelPlaceholder: 'gpt-4o-mini', + extraModels: '附加模型', + extraModelsPlaceholder: '回车添加附加模型', + groupName: '分组名称', + groupNamePlaceholder: '可选,用于在用户视图中聚合显示', + intervalSeconds: '检测间隔 (秒)', + intervalSecondsHint: '范围:15 - 3600 秒', + enabled: '启用监控', + kindRequired: '请选择供应商' + }, + runResultTitle: '检测结果', + noMonitorsYet: '暂无监控', + createFirstMonitor: '创建第一个监控来跟踪渠道可用性' + }, + // Subscriptions Management subscriptions: { title: '订阅管理', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index b97ccb5d..491a984d 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -360,6 +360,10 @@ const routes: RouteRecordRaw[] = [ }, { path: '/admin/channels', + redirect: '/admin/channels/pricing' + }, + { + path: '/admin/channels/pricing', name: 'AdminChannels', component: () => import('@/views/admin/ChannelsView.vue'), meta: { @@ -370,6 +374,29 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.channels.description' } }, + { + path: '/admin/channels/monitor', + name: 'AdminChannelMonitor', + component: () => import('@/views/admin/ChannelMonitorView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Channel Monitor', + titleKey: 'admin.channelMonitor.title', + descriptionKey: 'admin.channelMonitor.description' + } + }, + { + path: '/monitor', + name: 'ChannelStatus', + component: () => import('@/views/user/ChannelStatusView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: false, + title: 'Channel Status', + titleKey: 'nav.channelStatus' + } + }, { path: '/admin/subscriptions', name: 'AdminSubscriptions', diff --git a/frontend/src/views/admin/ChannelMonitorView.vue b/frontend/src/views/admin/ChannelMonitorView.vue new file mode 100644 index 00000000..8f0a1e2f --- /dev/null +++ b/frontend/src/views/admin/ChannelMonitorView.vue @@ -0,0 +1,295 @@ + + + diff --git a/frontend/src/views/user/ChannelStatusView.vue b/frontend/src/views/user/ChannelStatusView.vue new file mode 100644 index 00000000..9f5fe8d1 --- /dev/null +++ b/frontend/src/views/user/ChannelStatusView.vue @@ -0,0 +1,208 @@ + + + From 58b2cc380fefc180d96c3baf10d3214026c1341c Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:22:00 +0800 Subject: [PATCH 032/326] test: harden payment result resume flow --- frontend/src/views/user/PaymentResultView.vue | 50 +++---- .../user/__tests__/PaymentResultView.spec.ts | 132 ++++++++++++++++++ 2 files changed, 157 insertions(+), 25 deletions(-) create mode 100644 frontend/src/views/user/__tests__/PaymentResultView.spec.ts diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index 6431ddf6..e1db3ce2 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -94,6 +94,7 @@ import { ref, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { useRoute, useRouter } from 'vue-router' import OrderStatusBadge from '@/components/payment/OrderStatusBadge.vue' +import { PAYMENT_RECOVERY_STORAGE_KEY, readPaymentRecoverySnapshot } from '@/components/payment/paymentFlow' import { usePaymentStore } from '@/stores/payment' import { paymentAPI } from '@/api/payment' import type { PaymentOrder } from '@/types/payment' @@ -129,46 +130,46 @@ const feeAmount = computed(() => { }) const isSuccess = computed(() => { - // Always prioritize actual order status from backend - if (order.value) { - return SUCCESS_STATUSES.has(order.value.status) - } - // Fallback only when order not loaded - if (route.query.status === 'success') return true - if (route.query.trade_status === 'TRADE_SUCCESS') return true - return false + return !!order.value && SUCCESS_STATUSES.has(order.value.status) }) -/** Extract numeric order ID from out_trade_no like "sub2_46" → 46 */ -function parseOutTradeNo(outTradeNo: string): number { - const match = outTradeNo.match(/_(\d+)$/) - return match ? Number(match[1]) : 0 -} - onMounted(async () => { - // Try order_id first (internal navigation from QRCode/Stripe pages) + const resumeToken = typeof route.query.resume_token === 'string' + ? route.query.resume_token + : '' let orderId = Number(route.query.order_id) || 0 const outTradeNo = String(route.query.out_trade_no || '') - // Fallback: EasyPay return URL with out_trade_no - if (!orderId && outTradeNo) { - orderId = parseOutTradeNo(outTradeNo) - // Store return info for display when order lookup fails + if (!orderId && resumeToken && typeof window !== 'undefined') { + const restored = readPaymentRecoverySnapshot( + window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY), + { resumeToken }, + ) + if (restored?.orderId) { + orderId = restored.orderId + } + } + + if (orderId) { + try { + order.value = await paymentStore.pollOrderStatus(orderId) + } catch (_err: unknown) { + // Order lookup failed, will try legacy fallback below when possible. + } + } + + if (!order.value && outTradeNo) { returnInfo.value = { outTradeNo, money: String(route.query.money || ''), type: String(route.query.type || ''), tradeStatus: String(route.query.trade_status || ''), } - } - // Verify payment via public endpoint (works without login) - if (outTradeNo) { try { const result = await paymentAPI.verifyOrderPublic(outTradeNo) order.value = result.data } catch (_err: unknown) { - // Public verify failed, try authenticated endpoint if logged in try { const result = await paymentAPI.verifyOrder(outTradeNo) order.value = result.data @@ -176,12 +177,11 @@ onMounted(async () => { } } - // Normal order lookup by ID (if verify didn't load the order) if (!order.value && orderId) { try { order.value = await paymentStore.pollOrderStatus(orderId) } catch (_err: unknown) { - // Order lookup failed, will show returnInfo fallback + // Order lookup failed, will show returnInfo fallback. } } loading.value = false diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts new file mode 100644 index 00000000..b06217ab --- /dev/null +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -0,0 +1,132 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +const routeState = vi.hoisted(() => ({ + query: {} as Record, +})) + +const routerPush = vi.hoisted(() => vi.fn()) +const pollOrderStatus = vi.hoisted(() => vi.fn()) +const verifyOrderPublic = vi.hoisted(() => vi.fn()) +const verifyOrder = vi.hoisted(() => vi.fn()) + +vi.mock('vue-router', async () => { + const actual = await vi.importActual('vue-router') + return { + ...actual, + useRoute: () => routeState, + useRouter: () => ({ push: routerPush }), + } +}) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key, + }), + } +}) + +vi.mock('@/stores/payment', () => ({ + usePaymentStore: () => ({ + pollOrderStatus, + }), +})) + +vi.mock('@/api/payment', () => ({ + paymentAPI: { + verifyOrderPublic, + verifyOrder, + }, +})) + +import PaymentResultView from '../PaymentResultView.vue' +import { PAYMENT_RECOVERY_STORAGE_KEY } from '@/components/payment/paymentFlow' + +const orderFactory = (status: string) => ({ + id: 42, + user_id: 9, + amount: 88, + pay_amount: 88, + fee_rate: 0, + payment_type: 'alipay', + out_trade_no: 'sub2_20260420abcd1234', + status, + order_type: 'balance', + created_at: '2026-04-20T12:00:00Z', + expires_at: '2026-04-20T12:30:00Z', + refund_amount: 0, +}) + +describe('PaymentResultView', () => { + beforeEach(() => { + routeState.query = {} + routerPush.mockReset() + pollOrderStatus.mockReset() + verifyOrderPublic.mockReset() + verifyOrder.mockReset() + window.localStorage.clear() + }) + + it('restores order id from a matching resume token and does not trust query success flags', async () => { + routeState.query = { + resume_token: 'resume-42', + status: 'success', + } + window.localStorage.setItem(PAYMENT_RECOVERY_STORAGE_KEY, JSON.stringify({ + orderId: 42, + amount: 88, + qrCode: '', + expiresAt: '2099-01-01T00:10:00.000Z', + paymentType: 'alipay', + payUrl: 'https://pay.example.com/session/42', + clientSecret: '', + payAmount: 88, + orderType: 'balance', + paymentMode: 'redirect', + resumeToken: 'resume-42', + createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), + })) + pollOrderStatus.mockResolvedValue(orderFactory('PENDING')) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(pollOrderStatus).toHaveBeenCalledWith(42) + expect(verifyOrderPublic).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('payment.result.failed') + expect(wrapper.text()).not.toContain('payment.result.success') + }) + + it('keeps legacy out_trade_no verification as a fallback when no order context is available', async () => { + routeState.query = { + out_trade_no: 'legacy-123', + trade_status: 'TRADE_SUCCESS', + } + verifyOrderPublic.mockResolvedValue({ + data: orderFactory('PAID'), + }) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123') + expect(wrapper.text()).toContain('payment.result.success') + }) +}) From 40d4e167cd8093bc3f21d4dbe205946df9c0aead Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 20 Apr 2026 20:06:53 +0800 Subject: [PATCH 033/326] feat(payment): i18n payment error codes and label localization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pairs with the backend structured payment errors (reason + metadata). The frontend now maps reason codes to localized messages with metadata as interpolation variables, and automatically localizes raw config-field names (e.g. "certSerial" → "证书序列号") using the existing UI-label i18n namespace. - frontend/src/utils/apiError.ts - extractApiErrorCode now prefers the string `reason` over the numeric HTTP `code`; reason is granular enough to drive i18n lookup, HTTP code is not. - New extractApiErrorMetadata to pull interpolation params off the error. - New extractI18nErrorMessage(err, t, namespace, fallback): looks up `.` in i18n and substitutes metadata. Before substitution, `metadata.key` and `metadata.keys` (slash-joined) are re-translated through `admin.settings.payment.field_` so users see "缺少必填项:证书序列号" instead of "缺少必填项:certSerial". - frontend/src/i18n/locales/{zh,en}.ts - Add payment.errors entries for every structured reason code returned by the backend (PAYMENT_DISABLED, INVALID_AMOUNT, TOO_MANY_PENDING, DAILY_LIMIT_EXCEEDED, NO_AVAILABLE_INSTANCE, PAYMENT_PROVIDER_MISCONFIGURED, WXPAY_CONFIG_MISSING_KEY / INVALID_KEY_LENGTH / INVALID_KEY, NOT_FOUND, FORBIDDEN, CONFLICT, INVALID_ORDER_TYPE, INVALID_STATUS, BALANCE_NOT_ENOUGH, REFUND_AMOUNT_EXCEEDED, REFUND_FAILED, and more), with placeholders for template variables. - 13 payment-related Vue files - Migrate catch-block error reporting from extractApiErrorMessage to extractI18nErrorMessage(err, t, 'payment.errors', fallback). - Remove the ad-hoc paymentErrorMap computed in SettingsView.vue, which the new helper supersedes (it reads i18n directly via t). - frontend/src/components/payment/providerConfig.ts - wxpay: publicKey and publicKeyId are now required (was optional), matching the pubkey-only verifier direction; certSerial is already required. This PR is drop-in safe: reason-preferring extractApiErrorCode is backward compatible with callers that pass their own i18nMap, and error codes missing from i18n fall back to the existing message-based path. --- .../components/payment/PaymentQRDialog.vue | 4 +- .../components/payment/PaymentStatusPanel.vue | 4 +- .../payment/StripePaymentInline.vue | 8 +- .../src/components/payment/providerConfig.ts | 4 +- frontend/src/i18n/locales/en.ts | 26 ++++++ frontend/src/i18n/locales/zh.ts | 26 ++++++ frontend/src/utils/apiError.ts | 84 ++++++++++++++++++- frontend/src/views/admin/SettingsView.vue | 18 ++-- .../views/admin/orders/AdminOrdersView.vue | 10 +-- .../orders/AdminPaymentDashboardView.vue | 4 +- .../admin/orders/AdminPaymentPlansView.vue | 8 +- frontend/src/views/user/PaymentQRCodeView.vue | 4 +- frontend/src/views/user/PaymentView.vue | 6 +- frontend/src/views/user/StripePaymentView.vue | 6 +- frontend/src/views/user/StripePopupView.vue | 4 +- frontend/src/views/user/UserOrdersView.vue | 8 +- 16 files changed, 177 insertions(+), 47 deletions(-) diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue index db90c3b6..09d273cc 100644 --- a/frontend/src/components/payment/PaymentQRDialog.vue +++ b/frontend/src/components/payment/PaymentQRDialog.vue @@ -78,7 +78,7 @@ import Icon from '@/components/icons/Icon.vue' import { usePaymentStore } from '@/stores/payment' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { PaymentOrder } from '@/types/payment' import QRCode from 'qrcode' @@ -222,7 +222,7 @@ async function handleCancel() { cleanup() emit('close') } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { cancelling.value = false } diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue index 17541e59..53989dee 100644 --- a/frontend/src/components/payment/PaymentStatusPanel.vue +++ b/frontend/src/components/payment/PaymentStatusPanel.vue @@ -124,7 +124,7 @@ import { useI18n } from 'vue-i18n' import { usePaymentStore } from '@/stores/payment' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { PaymentOrder } from '@/types/payment' import Icon from '@/components/icons/Icon.vue' @@ -242,7 +242,7 @@ async function handleCancel() { cleanup() outcome.value = 'cancelled' } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { cancelling.value = false } diff --git a/frontend/src/components/payment/StripePaymentInline.vue b/frontend/src/components/payment/StripePaymentInline.vue index 3ddff8c8..bdb0dd6b 100644 --- a/frontend/src/components/payment/StripePaymentInline.vue +++ b/frontend/src/components/payment/StripePaymentInline.vue @@ -67,7 +67,7 @@ import { ref, onMounted, nextTick } from 'vue' import { useI18n } from 'vue-i18n' import { useRouter } from 'vue-router' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { paymentAPI } from '@/api/payment' import { useAppStore } from '@/stores' import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' @@ -132,7 +132,7 @@ onMounted(async () => { selectedType.value = event.value.type }) } catch (err: unknown) { - initError.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed')) + initError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed')) } finally { loading.value = false } @@ -186,7 +186,7 @@ async function handlePay() { emit('success') } } catch (err: unknown) { - error.value = extractApiErrorMessage(err, t('payment.result.failed')) + error.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed')) } finally { submitting.value = false } @@ -199,7 +199,7 @@ async function handleCancel() { await paymentAPI.cancelOrder(props.orderId) emit('back') } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { cancelling.value = false } diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts index bf2d4177..f4f5acdc 100644 --- a/frontend/src/components/payment/providerConfig.ts +++ b/frontend/src/components/payment/providerConfig.ts @@ -99,9 +99,9 @@ export const PROVIDER_CONFIG_FIELDS: Record = { { key: 'mchId', label: '', sensitive: false }, { key: 'privateKey', label: '', sensitive: true }, { key: 'apiV3Key', label: '', sensitive: true }, + { key: 'certSerial', label: '', sensitive: false }, { key: 'publicKey', label: '', sensitive: true }, - { key: 'publicKeyId', label: '', sensitive: false, optional: true }, - { key: 'certSerial', label: '', sensitive: false, optional: true }, + { key: 'publicKeyId', label: '', sensitive: false }, ], stripe: [ { key: 'secretKey', label: '', sensitive: true }, diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index c0a17d96..8213cb0f 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -5432,7 +5432,33 @@ export default { errors: { tooManyPending: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.', cancelRateLimited: 'Too many cancellations. Please try again later.', + // Structured error codes (reason strings from backend ApplicationError) + PAYMENT_DISABLED: 'Payment system is disabled.', + USER_INACTIVE: 'Your account is disabled.', + BALANCE_PAYMENT_DISABLED: 'Balance recharge has been disabled.', + INVALID_AMOUNT: 'Invalid amount.', + INVALID_INPUT: 'Invalid request.', + PLAN_NOT_AVAILABLE: 'Plan not found or no longer available.', + GROUP_NOT_FOUND: 'Subscription group is no longer available.', + GROUP_TYPE_MISMATCH: 'Group is not a subscription type.', + TOO_MANY_PENDING: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.', + DAILY_LIMIT_EXCEEDED: 'Daily recharge limit reached. Remaining: {remaining}.', + PAYMENT_GATEWAY_ERROR: 'Payment method is unavailable.', + NO_AVAILABLE_INSTANCE: 'No payment channel available right now.', + PAYMENT_PROVIDER_MISCONFIGURED: 'Payment provider misconfigured. Please contact an administrator.', + WXPAY_CONFIG_MISSING_KEY: 'WeChat Pay config missing required key: {key}.', + WXPAY_CONFIG_INVALID_KEY_LENGTH: 'WeChat Pay {key} length is invalid (expected {expected} bytes, got {actual}).', + WXPAY_CONFIG_INVALID_KEY: 'WeChat Pay {key} is malformed. Make sure you copied the full PEM content.', PENDING_ORDERS: 'This provider has pending orders. Please wait for them to complete before making changes.', + CANCEL_RATE_LIMITED: 'Too many cancellations. Please try again later.', + NOT_FOUND: 'Order not found.', + FORBIDDEN: 'No permission for this order.', + CONFLICT: 'Order status has changed. Please refresh.', + INVALID_ORDER_TYPE: 'Only balance orders can request a refund.', + INVALID_STATUS: 'The current order status does not allow this operation.', + BALANCE_NOT_ENOUGH: 'Refund amount exceeds balance.', + REFUND_AMOUNT_EXCEEDED: 'Refund amount exceeds the recharge amount.', + REFUND_FAILED: 'Refund failed.', }, stripePay: 'Pay Now', stripeSuccessProcessing: 'Payment successful, processing your order...', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index ba9edd7f..5f936965 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -5620,7 +5620,33 @@ export default { errors: { tooManyPending: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单', cancelRateLimited: '取消订单过于频繁,请稍后再试', + // Structured error codes (reason strings from backend ApplicationError) + PAYMENT_DISABLED: '支付系统已关闭', + USER_INACTIVE: '账号已被禁用', + BALANCE_PAYMENT_DISABLED: '余额充值功能已关闭', + INVALID_AMOUNT: '金额无效', + INVALID_INPUT: '参数有误', + PLAN_NOT_AVAILABLE: '套餐不存在或已下架', + GROUP_NOT_FOUND: '订阅分组不可用', + GROUP_TYPE_MISMATCH: '分组类型不是订阅类型', + TOO_MANY_PENDING: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单', + DAILY_LIMIT_EXCEEDED: '今日充值已达上限,剩余额度 {remaining}', + PAYMENT_GATEWAY_ERROR: '支付方式不可用', + NO_AVAILABLE_INSTANCE: '暂无可用的支付通道', + PAYMENT_PROVIDER_MISCONFIGURED: '支付通道配置错误,请联系管理员', + WXPAY_CONFIG_MISSING_KEY: '微信支付配置缺少必填项:{key}', + WXPAY_CONFIG_INVALID_KEY_LENGTH: '微信支付 {key} 长度错误,应为 {expected} 字节(实际 {actual})', + WXPAY_CONFIG_INVALID_KEY: '微信支付 {key} 格式错误,请确认复制了完整的 PEM 内容', PENDING_ORDERS: '该服务商有未完成的订单,请等待订单完成后再操作', + CANCEL_RATE_LIMITED: '取消订单过于频繁,请稍后再试', + NOT_FOUND: '订单不存在', + FORBIDDEN: '无权限操作此订单', + CONFLICT: '订单状态已变更,请刷新', + INVALID_ORDER_TYPE: '仅余额订单可申请退款', + INVALID_STATUS: '当前订单状态不允许此操作', + BALANCE_NOT_ENOUGH: '退款金额超过余额', + REFUND_AMOUNT_EXCEEDED: '退款金额超过充值金额', + REFUND_FAILED: '退款失败', }, stripePay: '立即支付', stripeSuccessProcessing: '支付成功,正在处理订单...', diff --git a/frontend/src/utils/apiError.ts b/frontend/src/utils/apiError.ts index e1fe0c30..07a17aca 100644 --- a/frontend/src/utils/apiError.ts +++ b/frontend/src/utils/apiError.ts @@ -23,14 +23,96 @@ interface ApiErrorLike { /** * Extract the error code from an API error object. + * + * Prefers the string `reason` (e.g. "PAYMENT_PROVIDER_MISCONFIGURED") over the + * numeric HTTP `code`, because reason is granular enough to drive i18n lookup + * while HTTP code is not. */ export function extractApiErrorCode(err: unknown): string | undefined { if (!err || typeof err !== 'object') return undefined const e = err as ApiErrorLike - const code = e.code ?? e.reason ?? e.response?.data?.code + const code = e.reason ?? e.code ?? e.response?.data?.code return code != null ? String(code) : undefined } +/** + * Extract metadata (interpolation params) from an API error object. + * Backend errors carry `metadata` with template variables that fill i18n placeholders. + */ +export function extractApiErrorMetadata(err: unknown): Record | undefined { + if (!err || typeof err !== 'object') return undefined + const e = err as ApiErrorLike + return e.metadata +} + +type TranslateFn = (key: string, params?: Record) => string +type TranslateWithExistsFn = TranslateFn & { te?: (key: string) => boolean } + +/** + * Translate a value via i18n if a matching key exists, otherwise return the original. + * Example: "certSerial" → t('admin.settings.payment.field_certSerial') → "证书序列号". + */ +function tryTranslate(t: TranslateFn, key: string, fallback: string): string { + const translated = t(key) + if (translated === key) return fallback + const te = (t as TranslateWithExistsFn).te + if (te && !te(key)) return fallback + return translated +} + +/** + * Replace raw config field names in metadata (e.g. "certSerial") with their + * localized UI labels (e.g. "证书序列号"), using the provider-config field i18n namespace. + * Handles both single `key` and `/`-joined `keys` patterns used by wxpay errors. + */ +function localizeMetadata(metadata: Record, t: TranslateFn): Record { + const out: Record = { ...metadata } + if (typeof out.key === 'string') { + out.key = tryTranslate(t, `admin.settings.payment.field_${out.key}`, out.key) + } + if (typeof out.keys === 'string') { + out.keys = out.keys + .split('/') + .map(k => tryTranslate(t, `admin.settings.payment.field_${k}`, k)) + .join(' / ') + } + return out +} + +/** + * Extract a localized error message from an API error by looking up + * `.` in i18n and substituting metadata as placeholders. + * + * Config-field names in metadata (`key` / `keys`) are automatically translated + * to their UI labels before substitution, so error messages read like + * "缺少必填项:证书序列号" instead of "缺少必填项:certSerial". + * + * @param err - The caught error + * @param t - Vue i18n translate function + * @param namespace- i18n key prefix, e.g. "payment.errors" + * @param fallback - Fallback key or plain string if no localized mapping exists + */ +export function extractI18nErrorMessage( + err: unknown, + t: TranslateFn, + namespace: string, + fallback: string, +): string { + const code = extractApiErrorCode(err) + if (code) { + const key = `${namespace}.${code}` + const rawMetadata = extractApiErrorMetadata(err) ?? {} + const metadata = localizeMetadata(rawMetadata, t) + const translated = t(key, metadata) + // Vue i18n returns the key itself when missing; detect that and fall back. + if (translated !== key) return translated + // If the framework exposes `te`, use it to double-check. + const te = (t as TranslateWithExistsFn).te + if (te && te(key)) return translated + } + return extractApiErrorMessage(err, fallback) +} + /** * Extract a displayable error message from an API error. * diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index ee6a4c6d..27cb1f0c 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2850,7 +2850,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue' import ImageUpload from '@/components/common/ImageUpload.vue' import BackupSettings from '@/views/admin/BackupView.vue' import { useClipboard } from '@/composables/useClipboard' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractApiErrorMessage, extractI18nErrorMessage } from '@/utils/apiError' import { useAppStore } from '@/stores' import { useAdminSettingsStore } from '@/stores/adminSettings' import { @@ -4085,14 +4085,10 @@ const cancelRateLimitModeOptions = computed(() => [ { value: 'fixed', label: t('admin.settings.payment.cancelRateLimitWindowModeFixed') }, ]) -const paymentErrorMap = computed(() => ({ - PENDING_ORDERS: t('payment.errors.PENDING_ORDERS'), -})) - async function loadProviders() { providersLoading.value = true try { const res = await adminAPI.payment.getProviders(); providers.value = res.data || [] } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { providersLoading.value = false } } @@ -4122,7 +4118,7 @@ async function handleSaveProvider(payload: Partial) { // Auto-save settings so provider changes take effect immediately await saveSettings() } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { providerSaving.value = false } @@ -4148,7 +4144,7 @@ async function handleToggleField(provider: ProviderInstance, field: 'enabled' | } else { provider.allow_user_refund = newValue } - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } async function handleToggleType(provider: ProviderInstance, type: string) { @@ -4158,7 +4154,7 @@ async function handleToggleType(provider: ProviderInstance, type: string) { try { await adminAPI.payment.updateProvider(provider.id, { supported_types: updated } as any) provider.supported_types = updated - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } function confirmDeleteProvider(provider: ProviderInstance) { @@ -4177,7 +4173,7 @@ async function handleReorderProviders(updates: { id: number; sort_order: number if (p) p.sort_order = u.sort_order } } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) loadProviders() } } @@ -4189,7 +4185,7 @@ async function handleDeleteProvider() { appStore.showSuccess(t('common.deleted')) showDeleteProviderDialog.value = false loadProviders() - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } onMounted(() => { diff --git a/frontend/src/views/admin/orders/AdminOrdersView.vue b/frontend/src/views/admin/orders/AdminOrdersView.vue index 027c8e5e..dd9fa7e6 100644 --- a/frontend/src/views/admin/orders/AdminOrdersView.vue +++ b/frontend/src/views/admin/orders/AdminOrdersView.vue @@ -116,7 +116,7 @@ import { ref, reactive, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' import { adminPaymentAPI } from '@/api/admin/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { formatOrderDateTime } from '@/components/payment/orderUtils' import type { PaymentOrder } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' @@ -167,7 +167,7 @@ async function loadOrders() { orders.value = res.data.items || [] orderPagination.total = res.data.total || 0 } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { ordersLoading.value = false } } @@ -214,12 +214,12 @@ async function showOrderDetail(order: PaymentOrder) { async function handleCancelOrder(order: PaymentOrder) { try { await adminPaymentAPI.cancelOrder(order.id); appStore.showSuccess(t('payment.admin.orderCancelled')); loadOrders() } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } async function handleRetryOrder(order: PaymentOrder) { try { await adminPaymentAPI.retryRecharge(order.id); appStore.showSuccess(t('payment.admin.retrySuccess')); loadOrders() } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } function openRefundDialog(order: PaymentOrder) { selectedOrder.value = order; showRefundDialog.value = true } @@ -230,7 +230,7 @@ async function handleRefund(data: { amount: number; reason: string; deduct_balan try { await adminPaymentAPI.refundOrder(selectedOrder.value.id, { amount: data.amount, reason: data.reason, deduct_balance: data.deduct_balance, force: data.force }) appStore.showSuccess(t('payment.admin.refundSuccess')); showRefundDialog.value = false; loadOrders() - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { refundSubmitting.value = false } } diff --git a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue index 06bc9218..5a80db44 100644 --- a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue +++ b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue @@ -72,7 +72,7 @@ import { ref, watch, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' import { adminPaymentAPI } from '@/api/admin/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import type { DashboardStats } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' import LoadingSpinner from '@/components/common/LoadingSpinner.vue' @@ -110,7 +110,7 @@ async function loadDashboard() { const res = await adminPaymentAPI.getDashboard(days.value) stats.value = res.data } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { loading.value = false } diff --git a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue index 876b2aa1..c2fc26fe 100644 --- a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue +++ b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue @@ -78,7 +78,7 @@ import { ref, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' import { adminPaymentAPI } from '@/api/admin/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import adminAPI from '@/api/admin' import type { SubscriptionPlan } from '@/types/payment' import type { AdminGroup } from '@/types' @@ -150,7 +150,7 @@ async function loadPlans() { : (p.features || []), })) } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { plansLoading.value = false } } @@ -166,7 +166,7 @@ async function toggleForSale(plan: SubscriptionPlan) { await adminPaymentAPI.updatePlan(plan.id, { for_sale: !plan.for_sale }) plan.for_sale = !plan.for_sale } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } @@ -174,7 +174,7 @@ function confirmDeletePlan(plan: SubscriptionPlan) { deletingPlanId.value = plan async function handleDeletePlan() { if (!deletingPlanId.value) return try { await adminPaymentAPI.deletePlan(deletingPlanId.value); appStore.showSuccess(t('common.deleted')); showDeletePlanDialog.value = false; loadPlans() } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } // ==================== Lifecycle ==================== diff --git a/frontend/src/views/user/PaymentQRCodeView.vue b/frontend/src/views/user/PaymentQRCodeView.vue index 0965947a..f844858d 100644 --- a/frontend/src/views/user/PaymentQRCodeView.vue +++ b/frontend/src/views/user/PaymentQRCodeView.vue @@ -39,7 +39,7 @@ import { useRoute, useRouter } from 'vue-router' import AppLayout from '@/components/layout/AppLayout.vue' import { usePaymentStore } from '@/stores/payment' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { useAppStore } from '@/stores' import QRCode from 'qrcode' import alipayIcon from '@/assets/icons/alipay.svg' @@ -167,7 +167,7 @@ async function handleCancel() { cleanup() router.push('/purchase') } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { cancelling.value = false } diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index 3f1401b3..e2885c80 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -271,7 +271,7 @@ import { usePaymentStore } from '@/stores/payment' import { useSubscriptionStore } from '@/stores/subscriptions' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { isMobileDevice } from '@/utils/device' import type { SubscriptionPlan, CheckoutInfoResponse, OrderType } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' @@ -610,7 +610,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } else if (apiErr.reason === 'CANCEL_RATE_LIMITED') { errorMessage.value = t('payment.errors.cancelRateLimited') } else { - errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed')) + errorMessage.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed')) } appStore.showError(errorMessage.value) } finally { @@ -648,7 +648,7 @@ onMounted(async () => { } } } - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { loading.value = false } // Fetch active subscriptions (uses cache, non-blocking) subscriptionStore.fetchActiveSubscriptions().catch(() => {}) diff --git a/frontend/src/views/user/StripePaymentView.vue b/frontend/src/views/user/StripePaymentView.vue index 20a4a408..3f73d4d5 100644 --- a/frontend/src/views/user/StripePaymentView.vue +++ b/frontend/src/views/user/StripePaymentView.vue @@ -99,7 +99,7 @@ import { useI18n } from 'vue-i18n' import { useRoute, useRouter } from 'vue-router' import { usePaymentStore } from '@/stores/payment' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { isMobileDevice } from '@/utils/device' import type { PaymentOrder } from '@/types/payment' import type { Stripe, StripeElements } from '@stripe/stripe-js' @@ -167,7 +167,7 @@ onMounted(async () => { mountPaymentElement(stripe, clientSecret) } } catch (err: unknown) { - initError.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed')) + initError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed')) } finally { loading.value = false } @@ -248,7 +248,7 @@ async function handleGenericPay() { scheduleClose() } } catch (err: unknown) { - stripeError.value = extractApiErrorMessage(err, t('payment.result.failed')) + stripeError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed')) } finally { stripeSubmitting.value = false } diff --git a/frontend/src/views/user/StripePopupView.vue b/frontend/src/views/user/StripePopupView.vue index 2704c62d..688ad644 100644 --- a/frontend/src/views/user/StripePopupView.vue +++ b/frontend/src/views/user/StripePopupView.vue @@ -56,7 +56,7 @@ import { computed, ref, onMounted, onUnmounted } from 'vue' import { useI18n } from 'vue-i18n' import { useRoute } from 'vue-router' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { isMobileDevice } from '@/utils/device' interface StripeWithWechatPay { @@ -143,7 +143,7 @@ async function initStripe(clientSecret: string, publishableKey: string) { } } } catch (err: unknown) { - error.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed')) + error.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed')) } } diff --git a/frontend/src/views/user/UserOrdersView.vue b/frontend/src/views/user/UserOrdersView.vue index ea888eb7..c3ed80eb 100644 --- a/frontend/src/views/user/UserOrdersView.vue +++ b/frontend/src/views/user/UserOrdersView.vue @@ -86,7 +86,7 @@ import { useI18n } from 'vue-i18n' import { useRouter } from 'vue-router' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import type { PaymentOrder } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' import Pagination from '@/components/common/Pagination.vue' @@ -128,7 +128,7 @@ async function fetchOrders() { orders.value = res.data.items || [] pagination.total = res.data.total || 0 } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { loading.value = false } @@ -148,7 +148,7 @@ async function confirmCancel() { cancelTargetId.value = null await fetchOrders() } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { actionLoading.value = false } @@ -166,7 +166,7 @@ async function confirmRefund() { refundReason.value = '' await fetchOrders() } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { actionLoading.value = false } From 97c9b992cbf8b658b6ef27c27fd0041893b74317 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:27:15 +0800 Subject: [PATCH 034/326] fix: require wechat unionid for oauth identity --- backend/internal/handler/auth_wechat_oauth.go | 6 +- .../handler/auth_wechat_oauth_test.go | 107 +++++++++++------- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 816f60fd..f0755f1f 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -193,11 +193,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID)) openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID)) - providerSubject := firstNonEmpty(unionid, openid) - if providerSubject == "" { - redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "") + if unionid == "" { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "") return } + providerSubject := unionid username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject)) email := wechatSyntheticEmail(providerSubject) diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 0d1df1b6..1ff80e1b 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -6,7 +6,6 @@ import ( "bytes" "context" "database/sql" - "encoding/base64" "net/http" "net/http/httptest" "net/url" @@ -122,6 +121,59 @@ func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) } +func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "https://app.example.com/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Contains(t, recorder.Header().Get("Location"), "#error=provider_error") + require.Contains(t, recorder.Header().Get("Location"), "error_message=wechat_missing_unionid") + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + + count, err := client.PendingAuthSession.Query().Count(context.Background()) + require.NoError(t, err) + require.Zero(t, count) +} + func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) { testCases := []struct { name string @@ -542,12 +594,7 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl userRepo := &oauthPendingFlowUserRepo{client: client} redeemRepo := repository.NewRedeemCodeRepository(client) - settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ - values: map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), - }, - }, &config.Config{ + cfg := &config.Config{ JWT: config.JWTConfig{ Secret: "test-secret", ExpireHour: 1, @@ -558,25 +605,20 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl UserBalance: 0, UserConcurrency: 1, }, - }) + } + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + }, + }, cfg) authSvc := service.NewAuthService( client, userRepo, redeemRepo, &wechatOAuthRefreshTokenCacheStub{}, - &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret", - ExpireHour: 1, - AccessTokenExpireMinutes: 60, - RefreshTokenExpireDays: 7, - }, - Default: config.DefaultConfig{ - UserBalance: 0, - UserConcurrency: 1, - }, - }, + cfg, settingSvc, nil, nil, @@ -588,33 +630,10 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl return &AuthHandler{ authService: authSvc, settingSvc: settingSvc, + cfg: cfg, }, client } -func encodedCookie(name, value string) *http.Cookie { - return &http.Cookie{ - Name: name, - Value: encodeCookieValue(value), - Path: "/", - } -} - -func findCookie(cookies []*http.Cookie, name string) *http.Cookie { - for _, cookie := range cookies { - if cookie.Name == name { - return cookie - } - } - return nil -} - -func decodeCookieValueForTest(t *testing.T, value string) string { - t.Helper() - raw, err := base64.RawURLEncoding.DecodeString(value) - require.NoError(t, err) - return string(raw) -} - func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { t.Helper() From 7c7924e9fa6125f47e2b080496466ab22cec40ad Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:31:19 +0800 Subject: [PATCH 035/326] fix: guard payment fulfillment provider mismatch --- .../internal/service/payment_fulfillment.go | 25 ++++++++ .../service/payment_fulfillment_test.go | 64 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 44818b37..519455f0 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -41,6 +41,19 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo slog.Error("order not found", "orderID", oid) return nil } + instanceProviderKey := "" + if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil { + instanceProviderKey = inst.ProviderKey + } + expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, instanceProviderKey) + if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) { + s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{ + "expectedProvider": expectedProviderKey, + "actualProvider": pk, + "tradeNo": tradeNo, + }) + return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk) + } // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). // Also skip if paid is NaN/Inf (malformed provider data). if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { @@ -56,6 +69,18 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo return s.toPaid(ctx, o, tradeNo, paid, pk) } +func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, instanceProviderKey string) string { + if key := strings.TrimSpace(instanceProviderKey); key != "" { + return key + } + if registry != nil { + if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" { + return key + } + } + return strings.TrimSpace(orderPaymentType) +} + func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error { previousStatus := o.Status now := time.Now() diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 625b0d9f..4cc00301 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -3,12 +3,37 @@ package service import ( + "context" "errors" "testing" + "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/assert" ) +type paymentFulfillmentTestProvider struct { + key string + supportedTypes []payment.PaymentType +} + +func (p paymentFulfillmentTestProvider) Name() string { return p.key } +func (p paymentFulfillmentTestProvider) ProviderKey() string { return p.key } +func (p paymentFulfillmentTestProvider) SupportedTypes() []payment.PaymentType { + return p.supportedTypes +} +func (p paymentFulfillmentTestProvider) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + // --------------------------------------------------------------------------- // resolveRedeemAction — pure idempotency decision logic // --------------------------------------------------------------------------- @@ -161,3 +186,42 @@ func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) { assert.True(t, unusedCode.CanUse()) assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil)) } + +func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay), + ) +} + +func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeEasyPay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, ""), + ) +} + +func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) { + t.Parallel() + + assert.Equal(t, + payment.TypeWxpay, + expectedNotificationProviderKey(nil, payment.TypeWxpay, ""), + ) +} From e3f69e02464fa79153aa16db7a457359f6a9a8a3 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:42:01 +0800 Subject: [PATCH 036/326] fix: tighten webhook provider resolution --- backend/internal/service/payment_service.go | 20 --- .../service/payment_webhook_provider.go | 86 +++++++++++ .../service/payment_webhook_provider_test.go | 141 ++++++++++++++++++ 3 files changed, 227 insertions(+), 20 deletions(-) create mode 100644 backend/internal/service/payment_webhook_provider.go create mode 100644 backend/internal/service/payment_webhook_provider_test.go diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index e897741a..d3175ba6 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -9,7 +9,6 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment/provider" @@ -225,25 +224,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) { } } -// GetWebhookProvider returns the provider instance that should verify a webhook. -// It extracts out_trade_no from the raw body, looks up the order to find the -// original provider instance, and creates a provider with that instance's credentials. -// Falls back to the registry provider when the order cannot be found. -func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { - if outTradeNo != "" { - order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) - if err == nil { - p, pErr := s.getOrderProvider(ctx, order) - if pErr == nil { - return p, nil - } - slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr) - } - } - s.EnsureProviders(ctx) - return s.registry.GetProviderByKey(providerKey) -} - // --- Helpers --- func psIsRefundStatus(s string) bool { diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go new file mode 100644 index 00000000..a877db2b --- /dev/null +++ b/backend/internal/service/payment_webhook_provider.go @@ -0,0 +1,86 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strconv" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/payment/provider" +) + +// GetWebhookProvider returns the provider instance that should verify a webhook. +// It resolves the original provider instance from the order whenever possible and +// only falls back to a registry provider for legacy/single-instance scenarios. +func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { + if outTradeNo != "" { + order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) + if err == nil { + if psHasPinnedProviderInstance(order) { + return s.getPinnedOrderProvider(ctx, order) + } + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) + } + s.EnsureProviders(ctx) + return s.registry.GetProviderByKey(providerKey) + } + } + + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) + } + + s.EnsureProviders(ctx) + return s.registry.GetProviderByKey(providerKey) +} + +func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst == nil { + return nil, fmt.Errorf("order %d provider instance is missing", o.ID) + } + + instID := strconv.FormatInt(int64(inst.ID), 10) + cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) + if err != nil { + return nil, fmt.Errorf("load provider instance config: %w", err) + } + + prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) + if err != nil { + return nil, fmt.Errorf("create pinned provider: %w", err) + } + return prov, nil +} + +func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool { + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" || s == nil || s.entClient == nil { + return false + } + + count, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.ProviderKeyEQ(providerKey), + paymentproviderinstance.EnabledEQ(true), + ). + Count(ctx) + if err != nil { + slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err) + return false + } + return count <= 1 +} + +func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool { + return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != "" +} diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go new file mode 100644 index 00000000..85c296de --- /dev/null +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -0,0 +1,141 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +type webhookProviderTestDouble struct { + key string + types []payment.PaymentType +} + +func (p webhookProviderTestDouble) Name() string { return p.key } +func (p webhookProviderTestDouble) ProviderKey() string { return p.key } +func (p webhookProviderTestDouble) SupportedTypes() []payment.PaymentType { return p.types } +func (p webhookProviderTestDouble) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + +func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-b"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + registry: payment.NewRegistry(), + providersLoaded: true, + } + + _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "") + require.Error(t, err) + require.Contains(t, err.Error(), "ambiguous") +} + +func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeStripe, + types: []payment.PaymentType{payment.TypeStripe}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + prov, err := svc.GetWebhookProvider(ctx, payment.TypeStripe, "") + require.NoError(t, err) + require.Equal(t, payment.TypeStripe, prov.ProviderKey()) +} + +func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("webhook@example.com"). + SetPasswordHash("hash"). + SetUsername("webhook"). + Save(ctx) + require.NoError(t, err) + + pinnedInstanceID := "999" + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("TEST-RECHARGE"). + SetOutTradeNo("sub2_test_pinned_order"). + SetPaymentType(payment.TypeWxpay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(pinnedInstanceID). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeWxpay, + types: []payment.PaymentType{payment.TypeWxpay}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "sub2_test_pinned_order") + require.Error(t, err) + require.Contains(t, err.Error(), "provider instance") +} From c0b24aefba926c61d749af11cd01cd6f96bb65fe Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:47:14 +0800 Subject: [PATCH 037/326] feat: snapshot payment provider keys on orders --- backend/ent/migrate/schema.go | 15 ++-- backend/ent/mutation.go | 75 ++++++++++++++++- backend/ent/paymentorder.go | 16 +++- backend/ent/paymentorder/paymentorder.go | 10 +++ backend/ent/paymentorder/where.go | 80 ++++++++++++++++++ backend/ent/paymentorder_create.go | 83 +++++++++++++++++++ backend/ent/paymentorder_update.go | 62 ++++++++++++++ backend/ent/runtime/runtime.go | 20 +++-- backend/ent/schema/payment_order.go | 4 + .../internal/service/payment_fulfillment.go | 7 +- .../service/payment_fulfillment_test.go | 21 ++++- backend/internal/service/payment_order.go | 8 +- .../service/payment_order_lifecycle.go | 13 ++- ...dd_payment_order_provider_key_snapshot.sql | 10 +++ 14 files changed, 400 insertions(+), 24 deletions(-) create mode 100644 backend/migrations/112_add_payment_order_provider_key_snapshot.sql diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index bf41e73b..230ea060 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -654,6 +654,7 @@ var ( {Name: "subscription_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "subscription_days", Type: field.TypeInt, Nullable: true}, {Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64}, + {Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30}, {Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"}, {Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}}, {Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, @@ -682,7 +683,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "payment_orders_users_payment_orders", - Columns: []*schema.Column{PaymentOrdersColumns[37]}, + Columns: []*schema.Column{PaymentOrdersColumns[38]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -696,32 +697,32 @@ var ( { Name: "paymentorder_user_id", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[37]}, + Columns: []*schema.Column{PaymentOrdersColumns[38]}, }, { Name: "paymentorder_status", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[19]}, + Columns: []*schema.Column{PaymentOrdersColumns[20]}, }, { Name: "paymentorder_expires_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[27]}, + Columns: []*schema.Column{PaymentOrdersColumns[28]}, }, { Name: "paymentorder_created_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[35]}, + Columns: []*schema.Column{PaymentOrdersColumns[36]}, }, { Name: "paymentorder_paid_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[28]}, + Columns: []*schema.Column{PaymentOrdersColumns[29]}, }, { Name: "paymentorder_payment_type_paid_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[28]}, + Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[29]}, }, { Name: "paymentorder_order_type", diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 12905c9a..5227015c 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -15385,6 +15385,7 @@ type PaymentOrderMutation struct { subscription_days *int addsubscription_days *int provider_instance_id *string + provider_key *string status *string refund_amount *float64 addrefund_amount *float64 @@ -16421,6 +16422,55 @@ func (m *PaymentOrderMutation) ResetProviderInstanceID() { delete(m.clearedFields, paymentorder.FieldProviderInstanceID) } +// SetProviderKey sets the "provider_key" field. +func (m *PaymentOrderMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PaymentOrderMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldProviderKey(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (m *PaymentOrderMutation) ClearProviderKey() { + m.provider_key = nil + m.clearedFields[paymentorder.FieldProviderKey] = struct{}{} +} + +// ProviderKeyCleared returns if the "provider_key" field was cleared in this mutation. +func (m *PaymentOrderMutation) ProviderKeyCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldProviderKey] + return ok +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PaymentOrderMutation) ResetProviderKey() { + m.provider_key = nil + delete(m.clearedFields, paymentorder.FieldProviderKey) +} + // SetStatus sets the "status" field. func (m *PaymentOrderMutation) SetStatus(s string) { m.status = &s @@ -17280,7 +17330,7 @@ func (m *PaymentOrderMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *PaymentOrderMutation) Fields() []string { - fields := make([]string, 0, 37) + fields := make([]string, 0, 38) if m.user != nil { fields = append(fields, paymentorder.FieldUserID) } @@ -17338,6 +17388,9 @@ func (m *PaymentOrderMutation) Fields() []string { if m.provider_instance_id != nil { fields = append(fields, paymentorder.FieldProviderInstanceID) } + if m.provider_key != nil { + fields = append(fields, paymentorder.FieldProviderKey) + } if m.status != nil { fields = append(fields, paymentorder.FieldStatus) } @@ -17438,6 +17491,8 @@ func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) { return m.SubscriptionDays() case paymentorder.FieldProviderInstanceID: return m.ProviderInstanceID() + case paymentorder.FieldProviderKey: + return m.ProviderKey() case paymentorder.FieldStatus: return m.Status() case paymentorder.FieldRefundAmount: @@ -17521,6 +17576,8 @@ func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.V return m.OldSubscriptionDays(ctx) case paymentorder.FieldProviderInstanceID: return m.OldProviderInstanceID(ctx) + case paymentorder.FieldProviderKey: + return m.OldProviderKey(ctx) case paymentorder.FieldStatus: return m.OldStatus(ctx) case paymentorder.FieldRefundAmount: @@ -17699,6 +17756,13 @@ func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error { } m.SetProviderInstanceID(v) return nil + case paymentorder.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil case paymentorder.FieldStatus: v, ok := value.(string) if !ok { @@ -17966,6 +18030,9 @@ func (m *PaymentOrderMutation) ClearedFields() []string { if m.FieldCleared(paymentorder.FieldProviderInstanceID) { fields = append(fields, paymentorder.FieldProviderInstanceID) } + if m.FieldCleared(paymentorder.FieldProviderKey) { + fields = append(fields, paymentorder.FieldProviderKey) + } if m.FieldCleared(paymentorder.FieldRefundReason) { fields = append(fields, paymentorder.FieldRefundReason) } @@ -18034,6 +18101,9 @@ func (m *PaymentOrderMutation) ClearField(name string) error { case paymentorder.FieldProviderInstanceID: m.ClearProviderInstanceID() return nil + case paymentorder.FieldProviderKey: + m.ClearProviderKey() + return nil case paymentorder.FieldRefundReason: m.ClearRefundReason() return nil @@ -18129,6 +18199,9 @@ func (m *PaymentOrderMutation) ResetField(name string) error { case paymentorder.FieldProviderInstanceID: m.ResetProviderInstanceID() return nil + case paymentorder.FieldProviderKey: + m.ResetProviderKey() + return nil case paymentorder.FieldStatus: m.ResetStatus() return nil diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go index 6ea3e709..a58823ee 100644 --- a/backend/ent/paymentorder.go +++ b/backend/ent/paymentorder.go @@ -56,6 +56,8 @@ type PaymentOrder struct { SubscriptionDays *int `json:"subscription_days,omitempty"` // ProviderInstanceID holds the value of the "provider_instance_id" field. ProviderInstanceID *string `json:"provider_instance_id,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey *string `json:"provider_key,omitempty"` // Status holds the value of the "status" field. Status string `json:"status,omitempty"` // RefundAmount holds the value of the "refund_amount" field. @@ -129,7 +131,7 @@ func (*PaymentOrder) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case paymentorder.FieldID, paymentorder.FieldUserID, paymentorder.FieldPlanID, paymentorder.FieldSubscriptionGroupID, paymentorder.FieldSubscriptionDays: values[i] = new(sql.NullInt64) - case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL: + case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldProviderKey, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL: values[i] = new(sql.NullString) case paymentorder.FieldRefundAt, paymentorder.FieldRefundRequestedAt, paymentorder.FieldExpiresAt, paymentorder.FieldPaidAt, paymentorder.FieldCompletedAt, paymentorder.FieldFailedAt, paymentorder.FieldCreatedAt, paymentorder.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -276,6 +278,13 @@ func (_m *PaymentOrder) assignValues(columns []string, values []any) error { _m.ProviderInstanceID = new(string) *_m.ProviderInstanceID = value.String } + case paymentorder.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = new(string) + *_m.ProviderKey = value.String + } case paymentorder.FieldStatus: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field status", values[i]) @@ -508,6 +517,11 @@ func (_m *PaymentOrder) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.ProviderKey; v != nil { + builder.WriteString("provider_key=") + builder.WriteString(*v) + } + builder.WriteString(", ") builder.WriteString("status=") builder.WriteString(_m.Status) builder.WriteString(", ") diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go index 4467b2b6..af9b1422 100644 --- a/backend/ent/paymentorder/paymentorder.go +++ b/backend/ent/paymentorder/paymentorder.go @@ -52,6 +52,8 @@ const ( FieldSubscriptionDays = "subscription_days" // FieldProviderInstanceID holds the string denoting the provider_instance_id field in the database. FieldProviderInstanceID = "provider_instance_id" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" // FieldRefundAmount holds the string denoting the refund_amount field in the database. @@ -123,6 +125,7 @@ var Columns = []string{ FieldSubscriptionGroupID, FieldSubscriptionDays, FieldProviderInstanceID, + FieldProviderKey, FieldStatus, FieldRefundAmount, FieldRefundReason, @@ -176,6 +179,8 @@ var ( OrderTypeValidator func(string) error // ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save. ProviderInstanceIDValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error // DefaultStatus holds the default value on creation for the "status" field. DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. @@ -301,6 +306,11 @@ func ByProviderInstanceID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldProviderInstanceID, opts...).ToFunc() } +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + // ByStatus orders the results by the status field. func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go index 78520fac..0f6b74a0 100644 --- a/backend/ent/paymentorder/where.go +++ b/backend/ent/paymentorder/where.go @@ -150,6 +150,11 @@ func ProviderInstanceID(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldProviderInstanceID, v)) } +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v)) +} + // Status applies equality check predicate on the "status" field. It's identical to StatusEQ. func Status(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v)) @@ -1360,6 +1365,81 @@ func ProviderInstanceIDContainsFold(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderInstanceID, v)) } +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyIsNil applies the IsNil predicate on the "provider_key" field. +func ProviderKeyIsNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderKey)) +} + +// ProviderKeyNotNil applies the NotNil predicate on the "provider_key" field. +func ProviderKeyNotNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderKey)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v)) +} + // StatusEQ applies the EQ predicate on the "status" field. func StatusEQ(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v)) diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go index 03098339..497ba52c 100644 --- a/backend/ent/paymentorder_create.go +++ b/backend/ent/paymentorder_create.go @@ -225,6 +225,20 @@ func (_c *PaymentOrderCreate) SetNillableProviderInstanceID(v *string) *PaymentO return _c } +// SetProviderKey sets the "provider_key" field. +func (_c *PaymentOrderCreate) SetProviderKey(v string) *PaymentOrderCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCreate { + if v != nil { + _c.SetProviderKey(*v) + } + return _c +} + // SetStatus sets the "status" field. func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate { _c.mutation.SetStatus(v) @@ -602,6 +616,11 @@ func (_c *PaymentOrderCreate) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if _, ok := _c.mutation.Status(); !ok { return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "PaymentOrder.status"`)} } @@ -748,6 +767,10 @@ func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec) _spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value) _node.ProviderInstanceID = &value } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = &value + } if value, ok := _c.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) _node.Status = value @@ -1201,6 +1224,24 @@ func (u *PaymentOrderUpsert) ClearProviderInstanceID() *PaymentOrderUpsert { return u } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsert) SetProviderKey(v string) *PaymentOrderUpsert { + u.Set(paymentorder.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsert) UpdateProviderKey() *PaymentOrderUpsert { + u.SetExcluded(paymentorder.FieldProviderKey) + return u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert { + u.SetNull(paymentorder.FieldProviderKey) + return u +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert { u.Set(paymentorder.FieldStatus, v) @@ -1880,6 +1921,27 @@ func (u *PaymentOrderUpsertOne) ClearProviderInstanceID() *PaymentOrderUpsertOne }) } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsertOne) SetProviderKey(v string) *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsertOne) UpdateProviderKey() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderKey() + }) +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderKey() + }) +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne { return u.Update(func(s *PaymentOrderUpsert) { @@ -2770,6 +2832,27 @@ func (u *PaymentOrderUpsertBulk) ClearProviderInstanceID() *PaymentOrderUpsertBu }) } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsertBulk) SetProviderKey(v string) *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsertBulk) UpdateProviderKey() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderKey() + }) +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderKey() + }) +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk { return u.Update(func(s *PaymentOrderUpsert) { diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go index 5978fc29..9a901415 100644 --- a/backend/ent/paymentorder_update.go +++ b/backend/ent/paymentorder_update.go @@ -385,6 +385,26 @@ func (_u *PaymentOrderUpdate) ClearProviderInstanceID() *PaymentOrderUpdate { return _u } +// SetProviderKey sets the "provider_key" field. +func (_u *PaymentOrderUpdate) SetProviderKey(v string) *PaymentOrderUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PaymentOrderUpdate) SetNillableProviderKey(v *string) *PaymentOrderUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate { + _u.mutation.ClearProviderKey() + return _u +} + // SetStatus sets the "status" field. func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate { _u.mutation.SetStatus(v) @@ -776,6 +796,11 @@ func (_u *PaymentOrderUpdate) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if v, ok := _u.mutation.Status(); ok { if err := paymentorder.StatusValidator(v); err != nil { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)} @@ -910,6 +935,12 @@ func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error if _u.mutation.ProviderInstanceIDCleared() { _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString) } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + } + if _u.mutation.ProviderKeyCleared() { + _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString) + } if value, ok := _u.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) } @@ -1399,6 +1430,26 @@ func (_u *PaymentOrderUpdateOne) ClearProviderInstanceID() *PaymentOrderUpdateOn return _u } +// SetProviderKey sets the "provider_key" field. +func (_u *PaymentOrderUpdateOne) SetProviderKey(v string) *PaymentOrderUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PaymentOrderUpdateOne) SetNillableProviderKey(v *string) *PaymentOrderUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne { + _u.mutation.ClearProviderKey() + return _u +} + // SetStatus sets the "status" field. func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne { _u.mutation.SetStatus(v) @@ -1803,6 +1854,11 @@ func (_u *PaymentOrderUpdateOne) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if v, ok := _u.mutation.Status(); ok { if err := paymentorder.StatusValidator(v); err != nil { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)} @@ -1954,6 +2010,12 @@ func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrd if _u.mutation.ProviderInstanceIDCleared() { _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString) } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + } + if _u.mutation.ProviderKeyCleared() { + _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString) + } if value, ok := _u.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 268e9ddb..b7118ac9 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -723,38 +723,42 @@ func init() { paymentorderDescProviderInstanceID := paymentorderFields[18].Descriptor() // paymentorder.ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save. paymentorder.ProviderInstanceIDValidator = paymentorderDescProviderInstanceID.Validators[0].(func(string) error) + // paymentorderDescProviderKey is the schema descriptor for provider_key field. + paymentorderDescProviderKey := paymentorderFields[19].Descriptor() + // paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error) // paymentorderDescStatus is the schema descriptor for status field. - paymentorderDescStatus := paymentorderFields[19].Descriptor() + paymentorderDescStatus := paymentorderFields[20].Descriptor() // paymentorder.DefaultStatus holds the default value on creation for the status field. paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string) // paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save. paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error) // paymentorderDescRefundAmount is the schema descriptor for refund_amount field. - paymentorderDescRefundAmount := paymentorderFields[20].Descriptor() + paymentorderDescRefundAmount := paymentorderFields[21].Descriptor() // paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field. paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64) // paymentorderDescForceRefund is the schema descriptor for force_refund field. - paymentorderDescForceRefund := paymentorderFields[23].Descriptor() + paymentorderDescForceRefund := paymentorderFields[24].Descriptor() // paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field. paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool) // paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field. - paymentorderDescRefundRequestedBy := paymentorderFields[26].Descriptor() + paymentorderDescRefundRequestedBy := paymentorderFields[27].Descriptor() // paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save. paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error) // paymentorderDescClientIP is the schema descriptor for client_ip field. - paymentorderDescClientIP := paymentorderFields[32].Descriptor() + paymentorderDescClientIP := paymentorderFields[33].Descriptor() // paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save. paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error) // paymentorderDescSrcHost is the schema descriptor for src_host field. - paymentorderDescSrcHost := paymentorderFields[33].Descriptor() + paymentorderDescSrcHost := paymentorderFields[34].Descriptor() // paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save. paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error) // paymentorderDescCreatedAt is the schema descriptor for created_at field. - paymentorderDescCreatedAt := paymentorderFields[35].Descriptor() + paymentorderDescCreatedAt := paymentorderFields[36].Descriptor() // paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field. paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time) // paymentorderDescUpdatedAt is the schema descriptor for updated_at field. - paymentorderDescUpdatedAt := paymentorderFields[36].Descriptor() + paymentorderDescUpdatedAt := paymentorderFields[37].Descriptor() // paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field. paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time) // paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go index a9576d2a..64378de1 100644 --- a/backend/ent/schema/payment_order.go +++ b/backend/ent/schema/payment_order.go @@ -91,6 +91,10 @@ func (PaymentOrder) Fields() []ent.Field { Optional(). Nillable(). MaxLen(64), + field.String("provider_key"). + Optional(). + Nillable(). + MaxLen(30), // 状态 field.String("status"). diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 519455f0..83bac21d 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -45,7 +45,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil { instanceProviderKey = inst.ProviderKey } - expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, instanceProviderKey) + expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, psStringValue(o.ProviderKey), instanceProviderKey) if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) { s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{ "expectedProvider": expectedProviderKey, @@ -69,10 +69,13 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo return s.toPaid(ctx, o, tradeNo, paid, pk) } -func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, instanceProviderKey string) string { +func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string { if key := strings.TrimSpace(instanceProviderKey); key != "" { return key } + if key := strings.TrimSpace(orderProviderKey); key != "" { + return key + } if registry != nil { if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" { return key diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 4cc00301..712129b0 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -198,7 +198,7 @@ func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing. assert.Equal(t, payment.TypeEasyPay, - expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay), + expectedNotificationProviderKey(registry, payment.TypeAlipay, "", payment.TypeEasyPay), ) } @@ -213,7 +213,7 @@ func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *te assert.Equal(t, payment.TypeEasyPay, - expectedNotificationProviderKey(registry, payment.TypeAlipay, ""), + expectedNotificationProviderKey(registry, payment.TypeAlipay, "", ""), ) } @@ -222,6 +222,21 @@ func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) { assert.Equal(t, payment.TypeWxpay, - expectedNotificationProviderKey(nil, payment.TypeWxpay, ""), + expectedNotificationProviderKey(nil, payment.TypeWxpay, "", ""), + ) +} + +func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""), ) } diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index fa256be7..4b9b1872 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -251,7 +251,13 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error())) } - _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx) + _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID). + SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)). + SetNillablePayURL(psNilIfEmpty(pr.PayURL)). + SetNillableQrCode(psNilIfEmpty(pr.QRCode)). + SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)). + SetNillableProviderKey(psNilIfEmpty(sel.ProviderKey)). + Save(ctx) if err != nil { return nil, fmt.Errorf("update order with payment details: %w", err) } diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index 80147180..f804eb8b 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -241,7 +242,10 @@ func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentO if err == nil { cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) if err == nil { - providerKey := s.registry.GetProviderKey(o.PaymentType) + providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) + if providerKey == "" { + providerKey = s.registry.GetProviderKey(o.PaymentType) + } if providerKey == "" { providerKey = o.PaymentType } @@ -255,3 +259,10 @@ func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentO s.EnsureProviders(ctx) return s.registry.GetProvider(o.PaymentType) } + +func psStringValue(value *string) string { + if value == nil { + return "" + } + return *value +} diff --git a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql new file mode 100644 index 00000000..7ec19ae3 --- /dev/null +++ b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql @@ -0,0 +1,10 @@ +ALTER TABLE payment_orders ADD COLUMN provider_key VARCHAR(30); + +UPDATE payment_orders +SET provider_key = ( + SELECT provider_key + FROM payment_provider_instances + WHERE CAST(id AS TEXT) = payment_orders.provider_instance_id +) +WHERE provider_key IS NULL + AND provider_instance_id IS NOT NULL; From 9bebf1c1a611e37c68a127ffee07a670460a6a09 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:53:46 +0800 Subject: [PATCH 038/326] feat: resolve payment results by resume token --- backend/internal/handler/payment_handler.go | 29 +++++ backend/internal/server/routes/payment.go | 1 + .../internal/service/payment_resume_lookup.go | 35 ++++++ .../service/payment_resume_lookup_test.go | 118 ++++++++++++++++++ frontend/src/api/payment.ts | 5 + frontend/src/views/user/PaymentResultView.vue | 12 +- .../user/__tests__/PaymentResultView.spec.ts | 26 ++++ 7 files changed, 225 insertions(+), 1 deletion(-) create mode 100644 backend/internal/service/payment_resume_lookup.go create mode 100644 backend/internal/service/payment_resume_lookup_test.go diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index d5c3d7c8..6440cdfd 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -357,6 +357,10 @@ type VerifyOrderRequest struct { OutTradeNo string `json:"out_trade_no" binding:"required"` } +type ResolveOrderByResumeTokenRequest struct { + ResumeToken string `json:"resume_token" binding:"required"` +} + // VerifyOrder actively queries the upstream payment provider to check // if payment was made, and processes it if so. // POST /api/v1/payment/orders/verify @@ -417,6 +421,31 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) { }) } +// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token. +// POST /api/v1/payment/public/orders/resolve +func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) { + var req ResolveOrderByResumeTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + order, err := h.paymentService.GetPublicOrderByResumeToken(c.Request.Context(), req.ResumeToken) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, PublicOrderResult{ + ID: order.ID, + OutTradeNo: order.OutTradeNo, + Amount: order.Amount, + PayAmount: order.PayAmount, + PaymentType: order.PaymentType, + OrderType: order.OrderType, + Status: order.Status, + }) +} + // requireAuth extracts the authenticated subject from the context. // Returns the subject and true on success; on failure it writes an Unauthorized response and returns false. func requireAuth(c *gin.Context) (middleware2.AuthSubject, bool) { diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 23bd58ad..dff14a70 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -49,6 +49,7 @@ func RegisterPaymentRoutes( public := v1.Group("/payment/public") { public.POST("/orders/verify", paymentHandler.VerifyOrderPublic) + public.POST("/orders/resolve", paymentHandler.ResolveOrderPublicByResumeToken) } // --- Webhook endpoints (no auth) --- diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go new file mode 100644 index 00000000..493ca325 --- /dev/null +++ b/backend/internal/service/payment_resume_lookup.go @@ -0,0 +1,35 @@ +package service + +import ( + "context" + "fmt" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" +) + +func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) { + claims, err := s.paymentResume().ParseToken(strings.TrimSpace(token)) + if err != nil { + return nil, err + } + + order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID) + if err != nil { + return nil, fmt.Errorf("get order by resume token: %w", err) + } + if claims.UserID > 0 && order.UserID != claims.UserID { + return nil, fmt.Errorf("resume token user mismatch") + } + if claims.ProviderInstanceID != "" && strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != claims.ProviderInstanceID { + return nil, fmt.Errorf("resume token provider instance mismatch") + } + if claims.ProviderKey != "" && strings.TrimSpace(psStringValue(order.ProviderKey)) != claims.ProviderKey { + return nil, fmt.Errorf("resume token provider key mismatch") + } + if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType { + return nil, fmt.Errorf("resume token payment type mismatch") + } + + return order, nil +} diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go new file mode 100644 index 00000000..d411398e --- /dev/null +++ b/backend/internal/service/payment_resume_lookup_test.go @@ -0,0 +1,118 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-user"). + Save(ctx) + require.NoError(t, err) + + instanceID := "12" + providerKey := payment.TypeEasyPay + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-ORDER"). + SetOutTradeNo("sub2_resume_lookup"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-1"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(instanceID). + SetProviderKey(providerKey). + Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: instanceID, + ProviderKey: providerKey, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + got, err := svc.GetPublicOrderByResumeToken(ctx, token) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) +} + +func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume-mismatch@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-mismatch-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-MISMATCH"). + SetOutTradeNo("sub2_resume_lookup_mismatch"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-2"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID("12"). + SetProviderKey(payment.TypeEasyPay). + Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: "99", + ProviderKey: payment.TypeEasyPay, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + _, err = svc.GetPublicOrderByResumeToken(ctx, token) + require.Error(t, err) + require.Contains(t, err.Error(), "resume token") +} diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts index 5cedb107..91b16866 100644 --- a/frontend/src/api/payment.ts +++ b/frontend/src/api/payment.ts @@ -72,6 +72,11 @@ export const paymentAPI = { return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo }) }, + /** Resolve an order from a signed resume token without auth */ + resolveOrderPublicByResumeToken(resumeToken: string) { + return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken }) + }, + /** Request a refund for a completed order */ requestRefund(id: number, data: { reason: string }) { return apiClient.post(`/payment/orders/${id}/refund-request`, data) diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index e1db3ce2..9687d1c7 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -150,7 +150,17 @@ onMounted(async () => { } } - if (orderId) { + if (!order.value && !orderId && resumeToken) { + try { + const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken) + order.value = result.data + orderId = result.data.id + } catch (_err: unknown) { + // Resume token recovery failed, continue to legacy fallback paths. + } + } + + if (!order.value && orderId) { try { order.value = await paymentStore.pollOrderStatus(orderId) } catch (_err: unknown) { diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts index b06217ab..d23a60d9 100644 --- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -9,6 +9,7 @@ const routerPush = vi.hoisted(() => vi.fn()) const pollOrderStatus = vi.hoisted(() => vi.fn()) const verifyOrderPublic = vi.hoisted(() => vi.fn()) const verifyOrder = vi.hoisted(() => vi.fn()) +const resolveOrderPublicByResumeToken = vi.hoisted(() => vi.fn()) vi.mock('vue-router', async () => { const actual = await vi.importActual('vue-router') @@ -39,6 +40,7 @@ vi.mock('@/api/payment', () => ({ paymentAPI: { verifyOrderPublic, verifyOrder, + resolveOrderPublicByResumeToken, }, })) @@ -67,6 +69,7 @@ describe('PaymentResultView', () => { pollOrderStatus.mockReset() verifyOrderPublic.mockReset() verifyOrder.mockReset() + resolveOrderPublicByResumeToken.mockReset() window.localStorage.clear() }) @@ -129,4 +132,27 @@ describe('PaymentResultView', () => { expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123') expect(wrapper.text()).toContain('payment.result.success') }) + + it('resolves order by resume token when local recovery snapshot is missing', async () => { + routeState.query = { + resume_token: 'resume-77', + } + resolveOrderPublicByResumeToken.mockResolvedValue({ + data: orderFactory('PAID'), + }) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-77') + expect(wrapper.text()).toContain('payment.result.success') + expect(verifyOrderPublic).not.toHaveBeenCalled() + }) }) From 32059ae9d5fcf7024fadee26cf1e9c32dcc6da4b Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:58:19 +0800 Subject: [PATCH 039/326] fix: backfill email identities on successful login --- .../internal/service/auth_oauth_email_flow.go | 6 +++ backend/internal/service/auth_service.go | 8 ++++ .../auth_service_identity_sync_test.go | 37 +++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index ca3403d4..4ab4e245 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -147,5 +147,11 @@ func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, pa // RecordSuccessfulLogin updates last-login activity after a non-standard login // flow finishes with a real session. func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) { + if s != nil && s.userRepo != nil && userID > 0 { + user, err := s.userRepo.GetByID(ctx, userID) + if err == nil { + s.backfillEmailIdentityOnSuccessfulLogin(ctx, user) + } + } s.touchUserLogin(ctx, userID) } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 40753139..dda6df04 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -430,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string if !user.IsActive() { return "", nil, ErrUserNotActive } + s.backfillEmailIdentityOnSuccessfulLogin(ctx, user) s.touchUserLogin(ctx, user.ID) // 生成JWT token @@ -802,6 +803,13 @@ func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) { } } +func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) { + if s == nil || user == nil || user.ID <= 0 { + return + } + s.ensureEmailAuthIdentity(ctx, user) +} + func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) { if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { return diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index 5bd2b25d..fcb4813b 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -150,4 +150,41 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { require.NotNil(t, storedUser.LastActiveAt) require.True(t, storedUser.LastLoginAt.After(old)) require.True(t, storedUser.LastActiveAt.After(old)) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("login@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) +} + +func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) { + svc, repo, client := newAuthServiceWithEnt(t) + ctx := context.Background() + + user := &service.User{ + Email: "record@example.com", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 1, + Concurrency: 1, + } + require.NoError(t, user.SetPassword("password")) + require.NoError(t, repo.Create(ctx, user)) + + svc.RecordSuccessfulLogin(ctx, user.ID) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("record@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) } From bdcd3d87e530fc1bbbec4888a51ae79c1cb0e298 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:09:38 +0800 Subject: [PATCH 040/326] fix: resolve unique legacy payment providers --- .../service/payment_order_lifecycle.go | 46 +++++--- backend/internal/service/payment_refund.go | 79 ++++++++++++- .../service/payment_webhook_provider.go | 22 ++-- .../service/payment_webhook_provider_test.go | 109 ++++++++++++++++++ 4 files changed, 220 insertions(+), 36 deletions(-) diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index f804eb8b..24d8b7a2 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -5,7 +5,6 @@ import ( "fmt" "log/slog" "strconv" - "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -237,29 +236,38 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) // getOrderProvider creates a provider using the order's original instance config. // Falls back to registry lookup if instance ID is missing (legacy orders). func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" { - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) - if err == nil { - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) - if err == nil { - providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) - if providerKey == "" { - providerKey = s.registry.GetProviderKey(o.PaymentType) - } - if providerKey == "" { - providerKey = o.PaymentType - } - p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg) - if err == nil { - return p, nil - } - } - } + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + return s.createProviderFromInstance(ctx, inst) } s.EnsureProviders(ctx) return s.registry.GetProvider(o.PaymentType) } +func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) { + if inst == nil { + return nil, fmt.Errorf("payment provider instance is missing") + } + + cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) + if err != nil { + return nil, fmt.Errorf("load provider instance config: %w", err) + } + if inst.PaymentMode != "" { + cfg["paymentMode"] = inst.PaymentMode + } + + instID := strconv.FormatInt(int64(inst.ID), 10) + prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) + if err != nil { + return nil, fmt.Errorf("create provider from instance: %w", err) + } + return prov, nil +} + func psStringValue(value *string) string { if value == nil { return "" diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index c5bda763..01eecff8 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -12,6 +12,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) @@ -19,18 +20,90 @@ import ( // --- Refund Flow --- // getOrderProviderInstance looks up the provider instance that processed this order. -// Returns nil, nil for legacy orders without provider_instance_id. +// For legacy orders without provider_instance_id, it resolves only when the +// enabled instance is uniquely identifiable from the stored order fields. func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { - if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" { + if s == nil || s.entClient == nil || o == nil { return nil, nil } - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) + + instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID)) + if instIDStr == "" { + return s.resolveUniqueLegacyOrderProviderInstance(ctx, o) + } + + instID, err := strconv.ParseInt(instIDStr, 10, 64) if err != nil { return nil, nil } return s.entClient.PaymentProviderInstance.Get(ctx, instID) } +func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) + if providerKey != "" { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.EnabledEQ(true), + paymentproviderinstance.ProviderKeyEQ(providerKey), + ). + All(ctx) + if err != nil { + return nil, err + } + if len(instances) == 1 { + return instances[0], nil + } + return nil, nil + } + + paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType)) + if paymentType == "" { + return nil, nil + } + + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)). + All(ctx) + if err != nil { + return nil, err + } + + var matched []*dbent.PaymentProviderInstance + for _, inst := range instances { + if psLegacyOrderMatchesInstance(paymentType, inst) { + matched = append(matched, inst) + } + } + if len(matched) == 1 { + return matched[0], nil + } + return nil, nil +} + +func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool { + if inst == nil { + return false + } + + baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType)) + instanceProviderKey := strings.TrimSpace(inst.ProviderKey) + if baseType == "" { + return false + } + + if baseType == payment.TypeStripe { + return instanceProviderKey == payment.TypeStripe + } + if instanceProviderKey == payment.TypeStripe { + return false + } + if instanceProviderKey == baseType { + return true + } + return payment.InstanceSupportsType(inst.SupportedTypes, baseType) +} + func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { o, err := s.validateRefundRequest(ctx, oid, uid) if err != nil { diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go index a877db2b..289d63ed 100644 --- a/backend/internal/service/payment_webhook_provider.go +++ b/backend/internal/service/payment_webhook_provider.go @@ -4,14 +4,12 @@ import ( "context" "fmt" "log/slog" - "strconv" "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" - "github.com/Wei-Shaw/sub2api/internal/payment/provider" ) // GetWebhookProvider returns the provider instance that should verify a webhook. @@ -24,6 +22,13 @@ func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, ou if psHasPinnedProviderInstance(order) { return s.getPinnedOrderProvider(ctx, order) } + inst, err := s.getOrderProviderInstance(ctx, order) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + return s.createProviderFromInstance(ctx, inst) + } if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) } @@ -48,18 +53,7 @@ func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.Pa if inst == nil { return nil, fmt.Errorf("order %d provider instance is missing", o.ID) } - - instID := strconv.FormatInt(int64(inst.ID), 10) - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) - if err != nil { - return nil, fmt.Errorf("load provider instance config: %w", err) - } - - prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) - if err != nil { - return nil, fmt.Errorf("create pinned provider: %w", err) - } - return prov, nil + return s.createProviderFromInstance(ctx, inst) } func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool { diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go index 85c296de..33e4186d 100644 --- a/backend/internal/service/payment_webhook_provider_test.go +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -4,13 +4,17 @@ package service import ( "context" + "encoding/json" "testing" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/require" ) +const webhookProviderTestEncryptionKey = "0123456789abcdef0123456789abcdef" + type webhookProviderTestDouble struct { key string types []payment.PaymentType @@ -32,6 +36,111 @@ func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest panic("unexpected call") } +func encryptWebhookProviderConfig(t *testing.T, config map[string]string) string { + t.Helper() + + data, err := json.Marshal(config) + require.NoError(t, err) + + encrypted, err := payment.Encrypt(string(data), []byte(webhookProviderTestEncryptionKey)) + require.NoError(t, err) + return encrypted +} + +func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalancer { + return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey)) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_test_legacy_provider_key"})). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeStripe + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeStripe, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyPaymentType(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpayDirect, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("easypay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { ctx := context.Background() client := newPaymentConfigServiceTestClient(t) From 5adefb466b9c7a1cd02e2c9f1c0221c08d38e5bf Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:24:33 +0800 Subject: [PATCH 041/326] fix: finalize oauth identity bindings --- .../handler/auth_oauth_pending_flow.go | 152 +++++++++++++++++- .../handler/auth_oauth_pending_flow_test.go | 11 +- backend/internal/handler/auth_wechat_oauth.go | 96 +++++++++++ .../handler/auth_wechat_oauth_test.go | 124 ++++++++++++++ 4 files changed, 376 insertions(+), 7 deletions(-) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 6d6564e8..21ed2bc6 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -10,6 +10,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" dbuser "github.com/Wei-Shaw/sub2api/ent/user" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -309,6 +310,14 @@ func cloneOAuthMetadata(values map[string]any) map[string]any { return cloned } +func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any { + merged := cloneOAuthMetadata(base) + for key, value := range overlay { + merged[key] = value + } + return merged +} + func normalizeAdoptedOAuthDisplayName(value string) string { value = strings.TrimSpace(value) if len([]rune(value)) > 100 { @@ -558,6 +567,10 @@ func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { } func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") { + return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID) + } + client := tx.Client() identity, err := client.AuthIdentity.Query(). Where( @@ -588,14 +601,149 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio return create.Save(ctx) } +func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + client := tx.Client() + providerType := strings.TrimSpace(session.ProviderType) + providerKey := strings.TrimSpace(session.ProviderKey) + providerSubject := strings.TrimSpace(session.ProviderSubject) + channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel")) + channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id")) + channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject")) + metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(providerSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if identity != nil && identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + + var legacyOpenIDIdentity *dbent.AuthIdentity + if channelSubject != "" && channelSubject != providerSubject { + legacyOpenIDIdentity, err = client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(channelSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + + switch { + case identity != nil: + update := client.AuthIdentity.UpdateOneID(identity.ID). + SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + case legacyOpenIDIdentity != nil: + update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID). + SetProviderSubject(providerSubject). + SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + default: + create := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, err + } + } + + if channel == "" || channelAppID == "" || channelSubject == "" { + return identity, nil + } + + channelRecord, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(channelSubject), + ). + WithIdentity(). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + + channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata) + if channelRecord == nil { + if _, err := client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetChannel(channel). + SetChannelAppID(channelAppID). + SetChannelSubject(channelSubject). + SetMetadata(channelMetadata). + Save(ctx); err != nil { + return nil, err + } + return identity, nil + } + + _, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). + SetIdentityID(identity.ID). + SetMetadata(channelMetadata). + Save(ctx) + if err != nil { + return nil, err + } + return identity, nil +} + +func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any { + if channel == nil { + return map[string]any{} + } + return cloneOAuthMetadata(channel.Metadata) +} + func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool { if session == nil || decision == nil { return false } - if strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user") { + switch strings.ToLower(strings.TrimSpace(session.Intent)) { + case "bind_current_user", "login", "adopt_existing_user_by_email": return true + default: + return decision.AdoptDisplayName || decision.AdoptAvatar } - return decision.AdoptDisplayName || decision.AdoptAvatar } func applyPendingOAuthBinding( diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index ae506e52..89accd60 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -372,7 +372,7 @@ func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testi require.Nil(t, storedSession.ConsumedAt) } -func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *testing.T) { +func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -420,21 +420,22 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *tes require.Equal(t, http.StatusOK, recorder.Code) - identityCount, err := client.AuthIdentity.Query(). + identity, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("linuxdo"), authidentity.ProviderKeyEQ("linuxdo"), authidentity.ProviderSubjectEQ("login-false-123"), ). - Count(ctx) + Only(ctx) require.NoError(t, err) - require.Zero(t, identityCount) + require.Equal(t, userEntity.ID, identity.UserID) decision, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). Only(ctx) require.NoError(t, err) - require.Nil(t, decision.IdentityID) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) require.False(t, decision.AdoptDisplayName) require.False(t, decision.AdoptAvatar) diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index f0755f1f..b6d47670 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -242,7 +242,18 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) return } + if existingIdentityUser == nil { + existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + } if existingIdentityUser != nil { + if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") if err != nil { redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) @@ -511,6 +522,91 @@ func (h *AuthHandler) ensureWeChatBindOwnership( return nil } +func (h *AuthHandler) findWeChatUserByLegacyOpenID( + ctx context.Context, + identity service.PendingAuthIdentityKey, + cfg wechatOAuthConfig, + openid string, +) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + openid = strings.TrimSpace(openid) + channel := strings.TrimSpace(cfg.mode) + channelAppID := strings.TrimSpace(cfg.appID) + if openid != "" && channel != "" && channelAppID != "" { + record, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(openid), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if record != nil && record.Edges.Identity != nil && record.Edges.Identity.Edges.User != nil { + return record.Edges.Identity.Edges.User, nil + } + } + + if openid == "" { + return nil, nil + } + + record, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderSubjectEQ(openid), + ). + WithUser(). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + return record.Edges.User, nil +} + +func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding( + ctx context.Context, + userID int64, + identity service.PendingAuthIdentityKey, + upstreamClaims map[string]any, +) error { + client := h.entClient() + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + tx, err := client.Tx(ctx) + if err != nil { + return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + _, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{ + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims), + }, userID) + if err != nil { + return err + } + return tx.Commit() +} + func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) { mode, err := resolveWeChatOAuthMode(rawMode, c) if err != nil { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 1ff80e1b..dd022fb9 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -15,6 +15,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" @@ -563,6 +564,19 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing require.Equal(t, "WeChat Display", identity.Metadata["display_name"]) require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"]) + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, identity.ID, channel.IdentityID) + require.Equal(t, "union-456", channel.Metadata["unionid"]) + decision, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). Only(ctx) @@ -579,6 +593,116 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing require.NotNil(t, consumed.ConsumedAt) } +func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("openid-123"). + SetMetadata(map[string]any{"openid": "openid-123"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + openIDIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("openid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, openIDIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() From f65429145e57c7bb61844fc978a3ef92e53710ef Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:31:05 +0800 Subject: [PATCH 042/326] fix: route legacy linuxdo users to account binding --- .../internal/handler/auth_linuxdo_oauth.go | 59 ++++++++++++++ .../handler/auth_linuxdo_oauth_test.go | 76 +++++++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index c3ec3804..a0760a3b 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -15,6 +15,8 @@ import ( "time" "unicode/utf8" + dbent "github.com/Wei-Shaw/sub2api/ent" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" @@ -237,6 +239,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") return } + compatEmail := strings.TrimSpace(email) // 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。 // 统一使用基于 subject 的稳定合成邮箱来做账号绑定。 @@ -255,6 +258,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { "suggested_display_name": displayName, "suggested_avatar_url": avatarURL, } + if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { + upstreamClaims["compat_email"] = compatEmail + } if intent == oauthIntentBindCurrentUser { targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName) if err != nil { @@ -314,6 +320,33 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } + compatEmailUser, err := h.findLinuxDoCompatEmailUser(c.Request.Context(), compatEmail) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if compatEmailUser != nil { + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: "adopt_existing_user_by_email", + Identity: identityKey, + TargetUserID: &compatEmailUser.ID, + ResolvedEmail: compatEmailUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + "step": "bind_login_required", + "email": compatEmailUser.Email, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil { redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") @@ -372,6 +405,32 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { redirectToFrontendCallback(c, frontendCallback) } +func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || + strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + return nil, nil + } + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEqualFold(email)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) + } + return userEntity, nil +} + type completeLinuxDoOAuthRequest struct { InvitationCode string `json:"invitation_code" binding:"required"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 765779b5..fb57e570 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -300,6 +300,82 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testin require.Nil(t, completion["error"]) } +func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"321","email":"legacy@example.com","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-compat&state=state-compat", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-compat")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-compat")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "adopt_existing_user_by_email", session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, existingUser.Email, session.ResolvedEmail) + require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, "bind_login_required", completion["step"]) + require.Equal(t, existingUser.Email, completion["email"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) +} + func TestLinuxDoOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { From 422f60a14513b8c9df33bdf8209d82fb6888d7b9 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:42:35 +0800 Subject: [PATCH 043/326] fix: normalize legacy wechat auth identity keys --- .../handler/auth_oauth_pending_flow.go | 109 ++++++++-- backend/internal/handler/auth_wechat_oauth.go | 122 ++++++++--- .../handler/auth_wechat_oauth_test.go | 192 ++++++++++++++++++ ...3_normalize_legacy_wechat_provider_key.sql | 89 ++++++++ 4 files changed, 464 insertions(+), 48 deletions(-) create mode 100644 backend/migrations/113_normalize_legacy_wechat_provider_key.sql diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 21ed2bc6..fd35e4e5 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -606,39 +606,42 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, providerType := strings.TrimSpace(session.ProviderType) providerKey := strings.TrimSpace(session.ProviderKey) providerSubject := strings.TrimSpace(session.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(providerKey) channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel")) channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id")) channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject")) metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims) - identity, err := client.AuthIdentity.Query(). + identityRecords, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ(providerType), - authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderKeyIn(providerKeys...), authidentity.ProviderSubjectEQ(providerSubject), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, err } - if identity != nil && identity.UserID != userID { - return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(identityRecords, userID, providerKey) + if err != nil { + return nil, err } var legacyOpenIDIdentity *dbent.AuthIdentity if channelSubject != "" && channelSubject != providerSubject { - legacyOpenIDIdentity, err = client.AuthIdentity.Query(). + legacyOpenIDRecords, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ(providerType), - authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderKeyIn(providerKeys...), authidentity.ProviderSubjectEQ(channelSubject), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, err } - if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID { - return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(legacyOpenIDRecords, userID, providerKey) + if err != nil { + return nil, err } } @@ -646,6 +649,9 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, case identity != nil: update := client.AuthIdentity.UpdateOneID(identity.ID). SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata)) + if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey { + update = update.SetProviderKey(providerKey) + } if issuer := oauthIdentityIssuer(session); issuer != nil { update = update.SetIssuer(strings.TrimSpace(*issuer)) } @@ -655,6 +661,7 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, } case legacyOpenIDIdentity != nil: update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID). + SetProviderKey(providerKey). SetProviderSubject(providerSubject). SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata)) if issuer := oauthIdentityIssuer(session); issuer != nil { @@ -684,21 +691,22 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, return identity, nil } - channelRecord, err := client.AuthIdentityChannel.Query(). + channelRecords, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ(providerType), - authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ProviderKeyIn(providerKeys...), authidentitychannel.ChannelEQ(channel), authidentitychannel.ChannelAppIDEQ(channelAppID), authidentitychannel.ChannelSubjectEQ(channelSubject), ). WithIdentity(). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, err } - if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID { - return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(channelRecords, userID, providerKey) + if err != nil { + return nil, err } channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata) @@ -717,16 +725,75 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, return identity, nil } - _, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). + updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). SetIdentityID(identity.ID). - SetMetadata(channelMetadata). - Save(ctx) + SetMetadata(channelMetadata) + if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey { + updateChannel = updateChannel.SetProviderKey(providerKey) + } + _, err = updateChannel.Save(ctx) if err != nil { return nil, err } return identity, nil } +func chooseWeChatIdentityForUser(records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) { + var preferred *dbent.AuthIdentity + var fallback *dbent.AuthIdentity + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.UserID != userID { + return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + +func chooseWeChatChannelForUser(records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) { + var preferred *dbent.AuthIdentityChannel + var fallback *dbent.AuthIdentityChannel + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any { if channel == nil { return map[string]any{} diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index b6d47670..95993dfc 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -34,6 +34,7 @@ const ( wechatOAuthDefaultRedirectTo = "/dashboard" wechatOAuthDefaultFrontendCB = "/auth/wechat/callback" wechatOAuthProviderKey = "wechat-main" + wechatOAuthLegacyProviderKey = "wechat" wechatOAuthIntentLogin = "login" wechatOAuthIntentBind = "bind_current_user" @@ -483,18 +484,20 @@ func (h *AuthHandler) ensureWeChatBindOwnership( return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") } - identity, err := client.AuthIdentity.Query(). + identities, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), - authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err) } - if identity != nil && identity.UserID != userID { - return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + for _, identity := range identities { + if identity != nil && identity.UserID != userID { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } } channelSubject = strings.TrimSpace(channelSubject) @@ -503,21 +506,23 @@ func (h *AuthHandler) ensureWeChatBindOwnership( return nil } - channel, err := client.AuthIdentityChannel.Query(). + channels, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ("wechat"), - authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)), authidentitychannel.ChannelAppIDEQ(channelAppID), authidentitychannel.ChannelSubjectEQ(channelSubject), ). WithIdentity(). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err) } - if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { - return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + for _, channel := range channels { + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } } return nil } @@ -533,14 +538,34 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") } + providerType := strings.TrimSpace(identity.ProviderType) + providerSubject := strings.TrimSpace(identity.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(identity.ProviderKey) + if providerSubject != "" { + records, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(providerSubject), + ). + WithUser(). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if user, err := singleWeChatIdentityUser(records); err != nil || user != nil { + return user, err + } + } + openid = strings.TrimSpace(openid) channel := strings.TrimSpace(cfg.mode) channelAppID := strings.TrimSpace(cfg.appID) if openid != "" && channel != "" && channelAppID != "" { - record, err := client.AuthIdentityChannel.Query(). + records, err := client.AuthIdentityChannel.Query(). Where( - authidentitychannel.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), - authidentitychannel.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyIn(providerKeys...), authidentitychannel.ChannelEQ(channel), authidentitychannel.ChannelAppIDEQ(channelAppID), authidentitychannel.ChannelSubjectEQ(openid), @@ -548,12 +573,12 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( WithIdentity(func(q *dbent.AuthIdentityQuery) { q.WithUser() }). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) } - if record != nil && record.Edges.Identity != nil && record.Edges.Identity.Edges.User != nil { - return record.Edges.Identity.Edges.User, nil + if user, err := singleWeChatChannelUser(records); err != nil || user != nil { + return user, err } } @@ -561,21 +586,64 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, nil } - record, err := client.AuthIdentity.Query(). + records, err := client.AuthIdentity.Query(). Where( - authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), - authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), authidentity.ProviderSubjectEQ(openid), ). WithUser(). - Only(ctx) + All(ctx) if err != nil { - if dbent.IsNotFound(err) { - return nil, nil - } return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } - return record.Edges.User, nil + return singleWeChatIdentityUser(records) +} + +func wechatCompatibleProviderKeys(providerKey string) []string { + preferred := strings.TrimSpace(providerKey) + if preferred == "" { + preferred = wechatOAuthProviderKey + } + keys := []string{preferred} + if !strings.EqualFold(preferred, wechatOAuthLegacyProviderKey) { + keys = append(keys, wechatOAuthLegacyProviderKey) + } + return keys +} + +func singleWeChatIdentityUser(records []*dbent.AuthIdentity) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.User + continue + } + if resolved.ID != record.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + return resolved, nil +} + +func singleWeChatChannelUser(records []*dbent.AuthIdentityChannel) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.Identity == nil || record.Edges.Identity.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.Identity.Edges.User + continue + } + if resolved.ID != record.Edges.Identity.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + } + return resolved, nil } func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding( diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index dd022fb9..def9d5d6 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -467,6 +467,88 @@ func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { require.Zero(t, count) } +func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") @@ -703,6 +785,116 @@ func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { require.Equal(t, repairedIdentity.ID, channel.IdentityID) } +func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + legacyIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, legacyIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() diff --git a/backend/migrations/113_normalize_legacy_wechat_provider_key.sql b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql new file mode 100644 index 00000000..15610af0 --- /dev/null +++ b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql @@ -0,0 +1,89 @@ +UPDATE auth_identities AS ai +SET + provider_key = 'wechat-main', + metadata = COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject + ); + +UPDATE auth_identity_channels AS channel +SET + provider_key = 'wechat-main', + metadata = COALESCE(channel.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identity_channels AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject + ); + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_provider_key_conflict', + CAST(ai.id AS TEXT), + jsonb_build_object( + 'legacy_identity_id', ai.id, + 'legacy_user_id', ai.user_id, + 'provider_subject', ai.provider_subject, + 'canonical_identity_id', canon.id, + 'canonical_user_id', canon.user_id, + 'same_user', canon.user_id = ai.user_id, + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identities AS ai +JOIN auth_identities AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_channel_provider_key_conflict', + CAST(channel.id AS TEXT), + jsonb_build_object( + 'legacy_channel_id', channel.id, + 'legacy_identity_id', channel.identity_id, + 'canonical_channel_id', canon.id, + 'canonical_identity_id', canon.identity_id, + 'channel', channel.channel, + 'channel_app_id', channel.channel_app_id, + 'channel_subject', channel.channel_subject, + 'same_user', COALESCE(legacy_identity.user_id = canonical_identity.user_id, FALSE), + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identity_channels AS channel +JOIN auth_identity_channels AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject +LEFT JOIN auth_identities AS legacy_identity + ON legacy_identity.id = channel.identity_id +LEFT JOIN auth_identities AS canonical_identity + ON canonical_identity.id = canon.identity_id +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING; From b30982219962c96d4e7c33d6aaeff4171eccbe88 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:46:24 +0800 Subject: [PATCH 044/326] fix: tighten legacy payment provider resolution --- backend/internal/service/payment_refund.go | 38 ++++++----- .../service/payment_webhook_provider_test.go | 64 +++++++++++++++++++ 2 files changed, 87 insertions(+), 15 deletions(-) diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index 01eecff8..57469fa3 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -21,7 +21,7 @@ import ( // getOrderProviderInstance looks up the provider instance that processed this order. // For legacy orders without provider_instance_id, it resolves only when the -// enabled instance is uniquely identifiable from the stored order fields. +// historical instance is uniquely identifiable from the stored order fields. func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { if s == nil || s.entClient == nil || o == nil { return nil, nil @@ -40,47 +40,55 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent. } func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType)) providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) if providerKey != "" { instances, err := s.entClient.PaymentProviderInstance.Query(). - Where( - paymentproviderinstance.EnabledEQ(true), - paymentproviderinstance.ProviderKeyEQ(providerKey), - ). + Where(paymentproviderinstance.ProviderKeyEQ(providerKey)). All(ctx) if err != nil { return nil, err } - if len(instances) == 1 { - return instances[0], nil + matched := psFilterLegacyOrderProviderInstances(paymentType, instances) + if len(matched) == 1 { + return matched[0], nil } return nil, nil } - paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType)) if paymentType == "" { return nil, nil } instances, err := s.entClient.PaymentProviderInstance.Query(). - Where(paymentproviderinstance.EnabledEQ(true)). All(ctx) if err != nil { return nil, err } - var matched []*dbent.PaymentProviderInstance - for _, inst := range instances { - if psLegacyOrderMatchesInstance(paymentType, inst) { - matched = append(matched, inst) - } - } + matched := psFilterLegacyOrderProviderInstances(paymentType, instances) if len(matched) == 1 { return matched[0], nil } return nil, nil } +func psFilterLegacyOrderProviderInstances(orderPaymentType string, instances []*dbent.PaymentProviderInstance) []*dbent.PaymentProviderInstance { + if len(instances) == 0 { + return nil + } + if strings.TrimSpace(orderPaymentType) == "" { + return instances + } + var matched []*dbent.PaymentProviderInstance + for _, inst := range instances { + if psLegacyOrderMatchesInstance(orderPaymentType, inst) { + matched = append(matched, inst) + } + } + return matched +} + func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool { if inst == nil { return false diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go index 33e4186d..4f0b6848 100644 --- a/backend/internal/service/payment_webhook_provider_test.go +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -141,6 +141,70 @@ func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing require.Nil(t, got) } +func TestGetOrderProviderInstanceLeavesLegacyProviderKeyUnresolvedWhenHistoricalInstancesConflict(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-disabled-legacy"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(false). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-enabled-current"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeStripe + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeStripe, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupported(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-only"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeWxpay + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipayDirect, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { ctx := context.Background() client := newPaymentConfigServiceTestClient(t) From 31d0183d45b3c5d2a6692daebec58625a393fe38 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:51:57 +0800 Subject: [PATCH 045/326] fix: normalize repository email lookups --- backend/internal/repository/user_repo.go | 30 +++++++- .../user_repo_email_lookup_unit_test.go | 69 +++++++++++++++++++ 2 files changed, 96 insertions(+), 3 deletions(-) create mode 100644 backend/internal/repository/user_repo_email_lookup_unit_test.go diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 0c607ecc..b5efd19d 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -12,6 +12,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" dbgroup "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/predicate" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -104,10 +105,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, } func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) { - m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + matches, err := r.client.User.Query(). + Where(userEmailLookupPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) if err != nil { - return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) + return nil, err } + if len(matches) == 0 { + return nil, service.ErrUserNotFound + } + if len(matches) > 1 { + return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email)) + } + m := matches[0] out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) @@ -469,7 +480,20 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount } func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { - return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) + return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx) +} + +func userEmailLookupPredicate(email string) predicate.User { + normalized := strings.TrimSpace(email) + if normalized == "" { + return dbuser.EmailEQ(email) + } + return predicate.User(func(s *entsql.Selector) { + s.Where(entsql.ExprP( + fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)), + normalized, + )) + }) } func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go new file mode 100644 index 00000000..d42ce9ac --- /dev/null +++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go @@ -0,0 +1,69 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return newUserRepositoryWithSQL(client, db), client +} + +func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + got, err := repo.GetByEmail(ctx, "legacy@example.com") + require.NoError(t, err) + require.Equal(t, " Legacy@Example.com ", got.Email) +} + +func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ") + require.NoError(t, err) + require.True(t, exists) +} From aaf4946b27ebc03163d0832b7946234f44b6b769 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:59:03 +0800 Subject: [PATCH 046/326] fix: normalize pending oauth email lookups --- .../handler/auth_oauth_pending_flow.go | 47 ++++++++-- .../handler/auth_oauth_pending_flow_test.go | 85 +++++++++++++++++++ 2 files changed, 126 insertions(+), 6 deletions(-) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index fd35e4e5..1b3c1380 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -3,6 +3,7 @@ package handler import ( "context" "errors" + "fmt" "io" "net/http" "net/url" @@ -12,12 +13,14 @@ import ( "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" dbuser "github.com/Wei-Shaw/sub2api/ent/user" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" + entsql "entgo.io/ent/dialect/sql" "github.com/gin-gonic/gin" ) @@ -531,11 +534,9 @@ func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing") } - userEntity, err := client.User.Query(). - Where(dbuser.EmailEQ(email)). - Only(ctx) + userEntity, err := findUserByNormalizedEmail(ctx, client, email) if err != nil { - if dbent.IsNotFound(err) { + if errors.Is(err, service.ErrUserNotFound) { return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found") } return 0, err @@ -543,6 +544,40 @@ func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, return userEntity.ID, nil } +func userNormalizedEmailPredicate(email string) predicate.User { + normalized := strings.TrimSpace(email) + if normalized == "" { + return dbuser.EmailEQ(email) + } + return predicate.User(func(s *entsql.Selector) { + s.Where(entsql.ExprP( + fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)), + normalized, + )) + }) +} + +func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) { + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + matches, err := client.User.Query(). + Where(userNormalizedEmailPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) + if err != nil { + return nil, err + } + if len(matches) == 0 { + return nil, service.ErrUserNotFound + } + if len(matches) > 1 { + return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users") + } + return matches[0], nil +} + func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { if session == nil { return nil @@ -1102,8 +1137,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) } email := strings.TrimSpace(strings.ToLower(req.Email)) - existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context()) - if err != nil && !dbent.IsNotFound(err) { + existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email) + if err != nil && !errors.Is(err, service.ErrUserNotFound) { response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 89accd60..8c468fdc 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -642,6 +642,60 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState require.Zero(t, identityCount) } +func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail(" Owner@Example.com "). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-normalized-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-normalized-123"). + SetBrowserSessionKey("existing-email-normalized-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Existing OIDC User", + "suggested_avatar_url": "https://cdn.example/existing.png", + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-normalized-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "adopt_existing_user_by_email", payload["intent"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) +} + func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -884,6 +938,37 @@ func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) { require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) } +func TestResolvePendingOAuthTargetUserIDNormalizesLegacySpacingAndCase(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + _ = handler + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail(" Owner@Example.com "). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("resolve-target-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-target-123"). + SetResolvedEmail("owner@example.com"). + SetBrowserSessionKey("resolve-target-browser-session-key"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session) + require.NoError(t, err) + require.Equal(t, existingUser.ID, resolvedUserID) +} + func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) { totpCache := &oauthPendingFlowTotpCacheStub{} handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ From bbc4aed3d91c8b3312c320843e3d4bea54e17d80 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 20 Apr 2026 22:01:09 +0800 Subject: [PATCH 047/326] =?UTF-8?q?fix(openai):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E5=B7=B2=E4=B8=8B=E7=BA=BF=20Codex=20=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=B9=B6=E4=BF=AE=E5=A4=8D=E5=BD=92=E4=B8=80=E5=8C=96=E5=85=9C?= =?UTF-8?q?=E5=BA=95=E5=89=AF=E4=BD=9C=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - backend: 删除 gpt-5 / 5.1 / 5.1-codex / 5.1-codex-max / 5.1-codex-mini / 5.2-codex / 5.4-nano 的内置映射与 DefaultModels 条目 - backend: normalizeCodexModel 默认兜底由 gpt-5.1 改为 gpt-5.4,gpt-5.3-codex-spark 独立保留映射 - backend: 修复 isOpenAIGPT54Model 与 shouldAutoInjectPromptCacheKeyForCompat 对 claude / gpt-4o 的误判(之前依赖 gpt-5.1 作为非 GPT 族的隐式 sentinel,改后需要显式前缀守卫) - backend: 清理 billing_service 中已不可达的 fallback 价格与 switch 分支 - frontend: 从白名单、OpenCode 配置、预设映射中移除已下线模型 - 同步更新所有相关单测 Refs: #1758, parallels upstream #1759 but adds downstream guard fixes --- backend/internal/pkg/openai/constants.go | 9 +- backend/internal/service/billing_service.go | 51 ++-------- .../internal/service/billing_service_test.go | 29 +----- .../service/openai_codex_transform.go | 76 +++------------ .../service/openai_codex_transform_test.go | 30 ++++-- .../service/openai_compat_prompt_cache_key.go | 10 +- .../service/openai_model_mapping_test.go | 14 +-- frontend/src/components/keys/UseKeyModal.vue | 92 ------------------- .../keys/__tests__/UseKeyModal.spec.ts | 4 +- .../__tests__/useModelWhitelist.spec.ts | 19 +++- frontend/src/composables/useModelWhitelist.ts | 15 +-- 11 files changed, 84 insertions(+), 265 deletions(-) diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index 49e38bf8..980f058d 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -17,16 +17,9 @@ type Model struct { var DefaultModels = []Model{ {ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"}, {ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"}, - {ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, - {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"}, - {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"}, - {ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"}, - {ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"}, - {ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"}, - {ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"}, } // DefaultModelIDs returns the default model ID list @@ -39,7 +32,7 @@ func DefaultModelIDs() []string { } // DefaultTestModel default model for testing OpenAI accounts -const DefaultTestModel = "gpt-5.1-codex" +const DefaultTestModel = "gpt-5.4" // DefaultInstructions default instructions for non-Codex CLI requests // Content loaded from instructions.txt at compile time diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index c9f32b3b..a45203a3 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -203,17 +203,6 @@ func (s *BillingService) initFallbackPricing() { SupportsCacheBreakdown: false, } - // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费) - s.fallbackPrices["gpt-5.1"] = &ModelPricing{ - InputPricePerToken: 1.25e-6, // $1.25 per MTok - InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok - OutputPricePerToken: 10e-6, // $10 per MTok - OutputPricePerTokenPriority: 20e-6, // $20 per MTok - CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok - CacheReadPricePerToken: 0.125e-6, - CacheReadPricePerTokenPriority: 0.25e-6, - SupportsCacheBreakdown: false, - } // OpenAI GPT-5.4(业务指定价格) s.fallbackPrices["gpt-5.4"] = &ModelPricing{ InputPricePerToken: 2.5e-6, // $2.5 per MTok @@ -234,12 +223,6 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 7.5e-8, SupportsCacheBreakdown: false, } - s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{ - InputPricePerToken: 2e-7, - OutputPricePerToken: 1.25e-6, - CacheReadPricePerToken: 2e-8, - SupportsCacheBreakdown: false, - } // OpenAI GPT-5.2(本地兜底) s.fallbackPrices["gpt-5.2"] = &ModelPricing{ InputPricePerToken: 1.75e-6, @@ -251,8 +234,8 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerTokenPriority: 0.35e-6, SupportsCacheBreakdown: false, } - // Codex 族兜底统一按 GPT-5.1 Codex 价格计费 - s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{ + // Codex 族兜底统一按 GPT-5.3 Codex 价格计费 + s.fallbackPrices["gpt-5.3-codex"] = &ModelPricing{ InputPricePerToken: 1.5e-6, // $1.5 per MTok InputPricePerTokenPriority: 3e-6, // $3 per MTok OutputPricePerToken: 12e-6, // $12 per MTok @@ -262,17 +245,6 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerTokenPriority: 0.3e-6, SupportsCacheBreakdown: false, } - s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{ - InputPricePerToken: 1.75e-6, - InputPricePerTokenPriority: 3.5e-6, - OutputPricePerToken: 14e-6, - OutputPricePerTokenPriority: 28e-6, - CacheCreationPricePerToken: 1.75e-6, - CacheReadPricePerToken: 0.175e-6, - CacheReadPricePerTokenPriority: 0.35e-6, - SupportsCacheBreakdown: false, - } - s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"] } // getFallbackPricing 根据模型系列获取回退价格 @@ -318,20 +290,12 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { switch normalized { case "gpt-5.4-mini": return s.fallbackPrices["gpt-5.4-mini"] - case "gpt-5.4-nano": - return s.fallbackPrices["gpt-5.4-nano"] case "gpt-5.4": return s.fallbackPrices["gpt-5.4"] case "gpt-5.2": return s.fallbackPrices["gpt-5.2"] - case "gpt-5.2-codex": - return s.fallbackPrices["gpt-5.2-codex"] - case "gpt-5.3-codex": + case "gpt-5.3-codex", "gpt-5.3-codex-spark": return s.fallbackPrices["gpt-5.3-codex"] - case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest": - return s.fallbackPrices["gpt-5.1-codex"] - case "gpt-5.1": - return s.fallbackPrices["gpt-5.1"] } } @@ -667,8 +631,13 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens } func isOpenAIGPT54Model(model string) bool { - normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model))) - return normalized == "gpt-5.4" + trimmed := strings.TrimSpace(strings.ToLower(model)) + // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel + // 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。 + if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { + return false + } + return normalizeCodexModel(trimmed) == "gpt-5.4" } // CalculateCostWithConfig 使用配置中的默认倍率计算费用 diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index fc8361c7..222abd69 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -123,15 +123,6 @@ func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) { require.Contains(t, err.Error(), "pricing not found") } -func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) { - svc := newTestBillingService() - - pricing, err := svc.GetModelPricing("gpt-5.1") - require.NoError(t, err) - require.NotNil(t, pricing) - require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12) -} - func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { svc := newTestBillingService() @@ -158,18 +149,6 @@ func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) { require.Zero(t, pricing.LongContextInputThreshold) } -func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) { - svc := newTestBillingService() - - pricing, err := svc.GetModelPricing("gpt-5.4-nano") - require.NoError(t, err) - require.NotNil(t, pricing) - require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12) - require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12) - require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12) - require.Zero(t, pricing.LongContextInputThreshold) -} - func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) { svc := newTestBillingService() @@ -204,13 +183,13 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) { {name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6}, {name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6}, {name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true}, - {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6}, {name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6}, {name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7}, - {name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7}, {name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6}, - {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6}, - {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6}, + {name: "openai gpt5.3 codex spark", model: "gpt-5.3-codex-spark", expectedInput: 1.5e-6}, + {name: "openai legacy gpt5.1 falls back to gpt5.4", model: "gpt-5.1", expectedInput: 2.5e-6}, + {name: "openai legacy gpt5.1 codex falls back to gpt5.3 codex", model: "gpt-5.1-codex", expectedInput: 1.5e-6}, + {name: "openai legacy codex mini latest falls back to gpt5.3 codex", model: "codex-mini-latest", expectedInput: 1.5e-6}, {name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true}, {name: "non supported family", model: "qwen-max", expectNilPricing: true}, } diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index a266d6a0..457309d3 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -8,7 +8,6 @@ import ( var codexModelMap = map[string]string{ "gpt-5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", - "gpt-5.4-nano": "gpt-5.4-nano", "gpt-5.4-none": "gpt-5.4", "gpt-5.4-low": "gpt-5.4", "gpt-5.4-medium": "gpt-5.4", @@ -22,52 +21,21 @@ var codexModelMap = map[string]string{ "gpt-5.3-high": "gpt-5.3-codex", "gpt-5.3-xhigh": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", - "gpt-5.3-codex-spark": "gpt-5.3-codex", - "gpt-5.3-codex-spark-low": "gpt-5.3-codex", - "gpt-5.3-codex-spark-medium": "gpt-5.3-codex", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-low": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-medium": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3-codex-low": "gpt-5.3-codex", "gpt-5.3-codex-medium": "gpt-5.3-codex", "gpt-5.3-codex-high": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt-5.1-codex": "gpt-5.1-codex", - "gpt-5.1-codex-low": "gpt-5.1-codex", - "gpt-5.1-codex-medium": "gpt-5.1-codex", - "gpt-5.1-codex-high": "gpt-5.1-codex", - "gpt-5.1-codex-max": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", "gpt-5.2": "gpt-5.2", "gpt-5.2-none": "gpt-5.2", "gpt-5.2-low": "gpt-5.2", "gpt-5.2-medium": "gpt-5.2", "gpt-5.2-high": "gpt-5.2", "gpt-5.2-xhigh": "gpt-5.2", - "gpt-5.2-codex": "gpt-5.2-codex", - "gpt-5.2-codex-low": "gpt-5.2-codex", - "gpt-5.2-codex-medium": "gpt-5.2-codex", - "gpt-5.2-codex-high": "gpt-5.2-codex", - "gpt-5.2-codex-xhigh": "gpt-5.2-codex", - "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5.1": "gpt-5.1", - "gpt-5.1-none": "gpt-5.1", - "gpt-5.1-low": "gpt-5.1", - "gpt-5.1-medium": "gpt-5.1", - "gpt-5.1-high": "gpt-5.1", - "gpt-5.1-chat-latest": "gpt-5.1", - "gpt-5-codex": "gpt-5.1-codex", - "codex-mini-latest": "gpt-5.1-codex-mini", - "gpt-5-codex-mini": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5": "gpt-5.1", - "gpt-5-mini": "gpt-5.1", - "gpt-5-nano": "gpt-5.1", } type codexTransformResult struct { @@ -220,7 +188,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact func normalizeCodexModel(model string) string { if model == "" { - return "gpt-5.1" + return "gpt-5.4" } modelID := model @@ -238,49 +206,29 @@ func normalizeCodexModel(model string) string { if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { return "gpt-5.4-mini" } - if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") { - return "gpt-5.4-nano" - } if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { return "gpt-5.4" } - if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") { - return "gpt-5.2-codex" - } if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") { return "gpt-5.2" } + if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") { + return "gpt-5.3-codex-spark" + } if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") { return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { return "gpt-5.3-codex" } - if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") { - return "gpt-5.1-codex-max" - } - if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") { - return "gpt-5.1-codex-mini" - } - if strings.Contains(normalized, "codex-mini-latest") || - strings.Contains(normalized, "gpt-5-codex-mini") || - strings.Contains(normalized, "gpt 5 codex mini") { - return "codex-mini-latest" - } - if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") { - return "gpt-5.1-codex" - } - if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") { - return "gpt-5.1" - } if strings.Contains(normalized, "codex") { - return "gpt-5.1-codex" + return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") { - return "gpt-5.1" + return "gpt-5.4" } - return "gpt-5.1" + return "gpt-5.4" } func normalizeOpenAIModelForUpstream(account *Account, model string) string { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 993ade07..22264f5e 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -240,15 +240,13 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { "gpt 5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", "gpt 5.4 mini": "gpt-5.4-mini", - "gpt-5.4-nano": "gpt-5.4-nano", - "gpt 5.4 nano": "gpt-5.4-nano", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt-5.3-codex-spark": "gpt-5.3-codex", - "gpt 5.3 codex spark": "gpt-5.3-codex", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt 5.3 codex spark": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt 5.3 codex": "gpt-5.3-codex", } @@ -257,6 +255,26 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { } } +func TestNormalizeCodexModel_RemovedModelsFallbackToSupportedTargets(t *testing.T) { + cases := map[string]string{ + "": "gpt-5.4", + "gpt-5": "gpt-5.4", + "gpt-5-mini": "gpt-5.4", + "gpt-5-nano": "gpt-5.4", + "gpt-5.1": "gpt-5.4", + "gpt-5.1-codex": "gpt-5.3-codex", + "gpt-5.1-codex-max": "gpt-5.3-codex", + "gpt-5.1-codex-mini": "gpt-5.3-codex", + "gpt-5.2-codex": "gpt-5.2", + "codex-mini-latest": "gpt-5.3-codex", + "gpt-5-codex": "gpt-5.3-codex", + } + + for input, expected := range cases { + require.Equal(t, expected, normalizeCodexModel(input)) + } +} + func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.3-codex-spark", diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go index 88e16a4d..fcd27f19 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key.go +++ b/backend/internal/service/openai_compat_prompt_cache_key.go @@ -10,8 +10,14 @@ import ( const compatPromptCacheKeyPrefix = "compat_cc_" func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { - switch normalizeCodexModel(strings.TrimSpace(model)) { - case "gpt-5.4", "gpt-5.3-codex": + trimmed := strings.TrimSpace(strings.ToLower(model)) + // 仅对 Codex OAuth 路径支持的 GPT-5 族开启自动注入,避免 normalizeCodexModel + // 的默认兜底把任意模型(如 gpt-4o、claude-*)误判为 gpt-5.4。 + if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { + return false + } + switch normalizeCodexModel(trimmed) { + case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": return true default: return false diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index cda7e369..35e7c250 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -69,14 +69,14 @@ func TestResolveOpenAIForwardModel(t *testing.T) { } } -func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) { +func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *testing.T) { account := &Account{ Credentials: map[string]any{}, } withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) - if withoutDefault != "gpt-5.1" { - t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1") + if withoutDefault != "gpt-5.4" { + t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4") } withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) @@ -87,9 +87,9 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * func TestNormalizeCodexModel(t *testing.T) { cases := map[string]string{ - "gpt-5.3-codex-spark": "gpt-5.3-codex", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3": "gpt-5.3-codex", } @@ -111,7 +111,7 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) { name: "oauth keeps codex normalization behavior", account: &Account{Type: AccountTypeOAuth}, model: "gemini-3-flash-preview", - want: "gpt-5.1", + want: "gpt-5.4", }, { name: "apikey preserves custom compatible model", diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue index 7770e658..b3679107 100644 --- a/frontend/src/components/keys/UseKeyModal.vue +++ b/frontend/src/components/keys/UseKeyModal.vue @@ -617,66 +617,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin } } const openaiModels = { - 'gpt-5-codex': { - name: 'GPT-5 Codex', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {} - } - }, - 'gpt-5.1-codex': { - name: 'GPT-5.1 Codex', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {} - } - }, - 'gpt-5.1-codex-max': { - name: 'GPT-5.1 Codex Max', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {} - } - }, - 'gpt-5.1-codex-mini': { - name: 'GPT-5.1 Codex Mini', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {} - } - }, 'gpt-5.2': { name: 'GPT-5.2', limit: { @@ -725,22 +665,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin xhigh: {} } }, - 'gpt-5.4-nano': { - name: 'GPT-5.4 Nano', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {}, - xhigh: {} - } - }, 'gpt-5.3-codex-spark': { name: 'GPT-5.3 Codex Spark', limit: { @@ -773,22 +697,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin xhigh: {} } }, - 'gpt-5.2-codex': { - name: 'GPT-5.2 Codex', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {}, - xhigh: {} - } - }, 'codex-mini-latest': { name: 'Codex Mini', limit: { diff --git a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts index 98b5dede..f7db586a 100644 --- a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts +++ b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts @@ -17,7 +17,7 @@ vi.mock('@/composables/useClipboard', () => ({ import UseKeyModal from '../UseKeyModal.vue' describe('UseKeyModal', () => { - it('renders updated GPT-5.4 mini/nano names in OpenCode config', async () => { + it('renders GPT-5.4 mini entry in OpenCode config', async () => { const wrapper = mount(UseKeyModal, { props: { show: true, @@ -48,6 +48,6 @@ describe('UseKeyModal', () => { const codeBlock = wrapper.find('pre code') expect(codeBlock.exists()).toBe(true) expect(codeBlock.text()).toContain('"name": "GPT-5.4 Mini"') - expect(codeBlock.text()).toContain('"name": "GPT-5.4 Nano"') + expect(codeBlock.text()).not.toContain('"name": "GPT-5.4 Nano"') }) }) diff --git a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts index 4061be4d..d35e3b12 100644 --- a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts +++ b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts @@ -12,10 +12,20 @@ describe('useModelWhitelist', () => { expect(models).toContain('gpt-5.4') expect(models).toContain('gpt-5.4-mini') - expect(models).toContain('gpt-5.4-nano') expect(models).toContain('gpt-5.4-2026-03-05') }) + it('openai 模型列表不再暴露已下线的 ChatGPT 登录 Codex 模型', () => { + const models = getModelsByPlatform('openai') + + expect(models).not.toContain('gpt-5') + expect(models).not.toContain('gpt-5.1') + expect(models).not.toContain('gpt-5.1-codex') + expect(models).not.toContain('gpt-5.1-codex-max') + expect(models).not.toContain('gpt-5.1-codex-mini') + expect(models).not.toContain('gpt-5.2-codex') + }) + it('antigravity 模型列表包含图片模型兼容项', () => { const models = getModelsByPlatform('antigravity') @@ -55,12 +65,11 @@ describe('useModelWhitelist', () => { }) }) - it('whitelist keeps GPT-5.4 mini and nano exact mappings', () => { - const mapping = buildModelMappingObject('whitelist', ['gpt-5.4-mini', 'gpt-5.4-nano'], []) + it('whitelist keeps GPT-5.4 mini exact mappings', () => { + const mapping = buildModelMappingObject('whitelist', ['gpt-5.4-mini'], []) expect(mapping).toEqual({ - 'gpt-5.4-mini': 'gpt-5.4-mini', - 'gpt-5.4-nano': 'gpt-5.4-nano' + 'gpt-5.4-mini': 'gpt-5.4-mini' }) }) }) diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index a282ae7d..ddd7af48 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -13,19 +13,11 @@ const openaiModels = [ 'o1', 'o1-preview', 'o1-mini', 'o1-pro', 'o3', 'o3-mini', 'o3-pro', 'o4-mini', - // GPT-5 系列(同步后端定价文件) - 'gpt-5', 'gpt-5-2025-08-07', 'gpt-5-chat', 'gpt-5-chat-latest', - 'gpt-5-codex', 'gpt-5.3-codex-spark', 'gpt-5-pro', 'gpt-5-pro-2025-10-06', - 'gpt-5-mini', 'gpt-5-mini-2025-08-07', - 'gpt-5-nano', 'gpt-5-nano-2025-08-07', - // GPT-5.1 系列 - 'gpt-5.1', 'gpt-5.1-2025-11-13', 'gpt-5.1-chat-latest', - 'gpt-5.1-codex', 'gpt-5.1-codex-max', 'gpt-5.1-codex-mini', // GPT-5.2 系列 'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest', - 'gpt-5.2-codex', 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11', + 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11', // GPT-5.4 系列 - 'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-nano', 'gpt-5.4-2026-03-05', + 'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-2026-03-05', // GPT-5.3 系列 'gpt-5.3-codex', 'gpt-5.3-codex-spark', 'chatgpt-4o-latest', @@ -264,12 +256,9 @@ const openaiPresetMappings = [ { label: 'GPT-4.1', from: 'gpt-4.1', to: 'gpt-4.1', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, { label: 'o1', from: 'o1', to: 'o1', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'o3', from: 'o3', to: 'o3', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' }, - { label: 'GPT-5', from: 'gpt-5', to: 'gpt-5', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }, { label: 'GPT-5.3 Codex Spark', from: 'gpt-5.3-codex-spark', to: 'gpt-5.3-codex-spark', color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' }, - { label: 'GPT-5.1', from: 'gpt-5.1', to: 'gpt-5.1', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' }, { label: 'GPT-5.2', from: 'gpt-5.2', to: 'gpt-5.2', color: 'bg-red-100 text-red-700 hover:bg-red-200 dark:bg-red-900/30 dark:text-red-400' }, { label: 'GPT-5.4', from: 'gpt-5.4', to: 'gpt-5.4', color: 'bg-rose-100 text-rose-700 hover:bg-rose-200 dark:bg-rose-900/30 dark:text-rose-400' }, - { label: 'GPT-5.1 Codex', from: 'gpt-5.1-codex', to: 'gpt-5.1-codex', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, { label: 'Haiku→5.4', from: 'claude-haiku-4-5-20251001', to: 'gpt-5.4', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' }, { label: 'Opus→5.4', from: 'claude-opus-4-6', to: 'gpt-5.4', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Sonnet→5.4', from: 'claude-sonnet-4-6', to: 'gpt-5.4', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' } From 3bd3027251dd5e008e60ba7299ef63dbc8f8a32c Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 22:05:33 +0800 Subject: [PATCH 048/326] feat: expose auth identity migration reports --- .../admin/admin_basic_handlers_test.go | 12 ++ .../handler/admin/admin_service_stub_test.go | 34 +++++ .../internal/handler/admin/user_handler.go | 25 ++++ backend/internal/server/routes/admin.go | 2 + backend/internal/service/admin_service.go | 135 ++++++++++++++++++ ..._service_identity_migration_report_test.go | 92 ++++++++++++ 6 files changed, 300 insertions(+) create mode 100644 backend/internal/service/admin_service_identity_migration_report_test.go diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index cba3ae21..ff9eec7e 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -22,6 +22,8 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { redeemHandler := NewRedeemHandler(adminSvc, nil) router.GET("/api/v1/admin/users", userHandler.List) + router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary) + router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) router.POST("/api/v1/admin/users", userHandler.Create) router.PUT("/api/v1/admin/users/:id", userHandler.Update) @@ -70,6 +72,16 @@ func TestUserHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/auth-identity-migration-reports/summary", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/auth-identity-migration-reports?report_type=oidc_synthetic_email_requires_manual_recovery", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil) router.ServeHTTP(rec, req) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 6d1ef1b6..681c25c6 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -17,6 +17,7 @@ type stubAdminService struct { proxies []service.Proxy proxyCounts []service.ProxyWithAccountCount redeems []service.RedeemCode + migrationReports []service.AuthIdentityMigrationReport createdAccounts []*service.CreateAccountInput createdProxies []*service.CreateProxyInput updatedProxyIDs []int64 @@ -123,6 +124,15 @@ func newStubAdminService() *stubAdminService { proxies: []service.Proxy{proxy}, proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}}, redeems: []service.RedeemCode{redeem}, + migrationReports: []service.AuthIdentityMigrationReport{ + { + ID: 1, + ReportType: "oidc_synthetic_email_requires_manual_recovery", + ReportKey: "u-1", + Details: map[string]any{"user_id": 1}, + CreatedAt: now, + }, + }, } } @@ -167,6 +177,30 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, return map[string]any{"user_id": userID}, nil } +func (s *stubAdminService) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]service.AuthIdentityMigrationReport, int64, error) { + if reportType == "" { + return s.migrationReports, int64(len(s.migrationReports)), nil + } + filtered := make([]service.AuthIdentityMigrationReport, 0, len(s.migrationReports)) + for _, report := range s.migrationReports { + if strings.EqualFold(report.ReportType, reportType) { + filtered = append(filtered, report) + } + } + return filtered, int64(len(filtered)), nil +} + +func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*service.AuthIdentityMigrationReportSummary, error) { + summary := &service.AuthIdentityMigrationReportSummary{ + ByType: map[string]int64{}, + } + for _, report := range s.migrationReports { + summary.Total++ + summary.ByType[report.ReportType]++ + } + return summary, nil +} + func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 1453bd07..ee3fbb1e 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -172,6 +172,31 @@ func (h *UserHandler) GetByID(c *gin.Context) { response.Success(c, dto.UserFromServiceAdmin(user)) } +// GetAuthIdentityMigrationReportSummary returns aggregate migration report counts. +// GET /api/v1/admin/users/auth-identity-migration-reports/summary +func (h *UserHandler) GetAuthIdentityMigrationReportSummary(c *gin.Context) { + summary, err := h.adminService.GetAuthIdentityMigrationReportSummary(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, summary) +} + +// ListAuthIdentityMigrationReports returns paginated auth identity migration reports. +// GET /api/v1/admin/users/auth-identity-migration-reports +func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + reportType := strings.TrimSpace(c.Query("report_type")) + + reports, total, err := h.adminService.ListAuthIdentityMigrationReports(c.Request.Context(), reportType, page, pageSize) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, reports, total, page, pageSize) +} + // Create handles creating a new user // POST /api/v1/admin/users func (h *UserHandler) Create(c *gin.Context) { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 9af0fd8e..0b5aaf09 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -210,6 +210,8 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users := admin.Group("/users") { + users.GET("/auth-identity-migration-reports/summary", h.Admin.User.GetAuthIdentityMigrationReportSummary) + users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports) users.GET("", h.Admin.User.List) users.GET("/:id", h.Admin.User.GetByID) users.POST("", h.Admin.User.Create) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 7c26a47c..972681a5 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2,6 +2,8 @@ package service import ( "context" + "database/sql" + "encoding/json" "errors" "fmt" "io" @@ -16,6 +18,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/util/httputil" + + entsql "entgo.io/ent/dialect/sql" ) // AdminService interface defines admin management operations @@ -33,6 +37,8 @@ type AdminService interface { // codeType is optional - pass empty string to return all types. // Also returns totalRecharged (sum of all positive balance top-ups). GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) + ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) + GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) @@ -127,6 +133,19 @@ type UpdateUserInput struct { GroupRates map[int64]*float64 } +type AuthIdentityMigrationReport struct { + ID int64 `json:"id"` + ReportType string `json:"report_type"` + ReportKey string `json:"report_key"` + Details map[string]any `json:"details"` + CreatedAt time.Time `json:"created_at"` +} + +type AuthIdentityMigrationReportSummary struct { + Total int64 `json:"total"` + ByType map[string]int64 `json:"by_type"` +} + type CreateGroupInput struct { Name string Description string @@ -788,6 +807,122 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int return codes, result.Total, totalRecharged, nil } +func (s *adminServiceImpl) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) { + db, err := s.adminSQLDB() + if err != nil { + return nil, 0, err + } + + reportType = strings.TrimSpace(reportType) + if page <= 0 { + page = 1 + } + if pageSize <= 0 { + pageSize = 20 + } + offset := (page - 1) * pageSize + + var total int64 + if err := db.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE ($1 = '' OR report_type = $1)`, + reportType, + ).Scan(&total); err != nil { + return nil, 0, err + } + + rows, err := db.QueryContext(ctx, ` +SELECT id, report_type, report_key, details, created_at +FROM auth_identity_migration_reports +WHERE ($1 = '' OR report_type = $1) +ORDER BY created_at DESC, id DESC +LIMIT $2 OFFSET $3`, + reportType, + pageSize, + offset, + ) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + reports := make([]AuthIdentityMigrationReport, 0) + for rows.Next() { + report, scanErr := scanAuthIdentityMigrationReport(rows) + if scanErr != nil { + return nil, 0, scanErr + } + reports = append(reports, report) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return reports, total, nil +} + +func (s *adminServiceImpl) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) { + db, err := s.adminSQLDB() + if err != nil { + return nil, err + } + + rows, err := db.QueryContext(ctx, ` +SELECT report_type, COUNT(*) +FROM auth_identity_migration_reports +GROUP BY report_type +ORDER BY report_type ASC`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + summary := &AuthIdentityMigrationReportSummary{ + ByType: make(map[string]int64), + } + for rows.Next() { + var reportType string + var count int64 + if err := rows.Scan(&reportType, &count); err != nil { + return nil, err + } + summary.ByType[reportType] = count + summary.Total += count + } + if err := rows.Err(); err != nil { + return nil, err + } + return summary, nil +} + +func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { + if s == nil || s.entClient == nil { + return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") + } + driver, ok := s.entClient.Driver().(*entsql.Driver) + if !ok || driver.DB() == nil { + return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") + } + return driver.DB(), nil +} + +func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { + var ( + report AuthIdentityMigrationReport + details []byte + ) + if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil { + return AuthIdentityMigrationReport{}, err + } + report.Details = map[string]any{} + if len(details) > 0 { + if err := json.Unmarshal(details, &report.Details); err != nil { + return AuthIdentityMigrationReport{}, err + } + } + return report, nil +} + // Group management implementations func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} diff --git a/backend/internal/service/admin_service_identity_migration_report_test.go b/backend/internal/service/admin_service_identity_migration_report_test.go new file mode 100644 index 00000000..75ca3e5a --- /dev/null +++ b/backend/internal/service/admin_service_identity_migration_report_test.go @@ -0,0 +1,92 @@ +package service + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAdminServiceMigrationReportTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:admin_service_migration_reports?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + _, err = db.Exec(`CREATE TABLE auth_identity_migration_reports ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + report_type TEXT NOT NULL, + report_key TEXT NOT NULL, + details TEXT NOT NULL DEFAULT '{}', + created_at DATETIME NOT NULL + )`) + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestAdminServiceListAuthIdentityMigrationReports(t *testing.T) { + client := newAdminServiceMigrationReportTestClient(t) + driver, ok := client.Driver().(*entsql.Driver) + require.True(t, ok) + + now := time.Now().UTC() + _, err := driver.DB().ExecContext(context.Background(), ` +INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at) +VALUES + ($1, $2, $3, $4), + ($5, $6, $7, $8)`, + "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now, + "wechat_provider_key_conflict", "u-2", `{"user_id":2}`, now.Add(-time.Minute), + ) + require.NoError(t, err) + + svc := &adminServiceImpl{entClient: client} + reports, total, err := svc.ListAuthIdentityMigrationReports(context.Background(), "oidc_synthetic_email_requires_manual_recovery", 1, 20) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Len(t, reports, 1) + require.Equal(t, "oidc_synthetic_email_requires_manual_recovery", reports[0].ReportType) + require.Equal(t, float64(1), reports[0].Details["user_id"]) +} + +func TestAdminServiceGetAuthIdentityMigrationReportSummary(t *testing.T) { + client := newAdminServiceMigrationReportTestClient(t) + driver, ok := client.Driver().(*entsql.Driver) + require.True(t, ok) + + now := time.Now().UTC() + _, err := driver.DB().ExecContext(context.Background(), ` +INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at) +VALUES + ($1, $2, $3, $4), + ($5, $6, $7, $8), + ($9, $10, $11, $12)`, + "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now, + "wechat_provider_key_conflict", "u-2", `{"user_id":2}`, now.Add(-time.Minute), + "wechat_provider_key_conflict", "u-3", `{"user_id":3}`, now.Add(-2*time.Minute), + ) + require.NoError(t, err) + + svc := &adminServiceImpl{entClient: client} + summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(3), summary.Total) + require.Equal(t, int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"]) + require.Equal(t, int64(2), summary.ByType["wechat_provider_key_conflict"]) +} From 452e55a53c7a22c3e4a0da2421d19cfb6819de96 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 22:22:14 +0800 Subject: [PATCH 049/326] feat: add admin auth identity repair binding --- .../admin/admin_basic_handlers_test.go | 48 +++- .../handler/admin/admin_service_stub_test.go | 48 ++++ .../internal/handler/admin/user_handler.go | 55 ++++ backend/internal/server/routes/admin.go | 1 + backend/internal/service/admin_service.go | 262 ++++++++++++++++++ ...dmin_service_auth_identity_binding_test.go | 215 ++++++++++++++ 6 files changed, 628 insertions(+), 1 deletion(-) create mode 100644 backend/internal/service/admin_service_auth_identity_binding_test.go diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index ff9eec7e..57620005 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -25,6 +25,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary) router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) + router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity) router.POST("/api/v1/admin/users", userHandler.Create) router.PUT("/api/v1/admin/users/:id", userHandler.Update) router.DELETE("/api/v1/admin/users/:id", userHandler.Delete) @@ -87,8 +88,26 @@ func TestUserHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + bindBody := map[string]any{ + "provider_type": "wechat", + "provider_key": "wechat-main", + "provider_subject": "union-123", + "metadata": map[string]any{"source": "admin-repair"}, + "channel": map[string]any{ + "channel": "open", + "channel_app_id": "wx-open", + "channel_subject": "openid-123", + }, + } + body, _ := json.Marshal(bindBody) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2} - body, _ := json.Marshal(createBody) + body, _ = json.Marshal(createBody) rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -125,6 +144,33 @@ func TestUserHandlerEndpoints(t *testing.T) { require.Equal(t, http.StatusOK, rec.Code) } +func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) { + router, adminSvc := setupAdminRouter() + + body, err := json.Marshal(map[string]any{ + "provider_type": "oidc", + "provider_key": "https://issuer.example", + "provider_subject": "subject-123", + "issuer": "https://issuer.example", + "metadata": map[string]any{"report_id": 12}, + }) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor) + require.NotNil(t, adminSvc.boundAuthIdentity) + require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType) + require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey) + require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject) + require.Nil(t, adminSvc.boundAuthIdentity.Channel) + require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"]) +} + func TestGroupHandlerEndpoints(t *testing.T) { router, _ := setupAdminRouter() diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 681c25c6..c8c7a247 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -18,6 +18,8 @@ type stubAdminService struct { proxyCounts []service.ProxyWithAccountCount redeems []service.RedeemCode migrationReports []service.AuthIdentityMigrationReport + boundAuthIdentity *service.AdminBindAuthIdentityInput + boundAuthIdentityFor int64 createdAccounts []*service.CreateAccountInput createdProxies []*service.CreateProxyInput updatedProxyIDs []int64 @@ -201,6 +203,52 @@ func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Con return summary, nil } +func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) { + s.boundAuthIdentityFor = userID + copied := input + if input.Metadata != nil { + copied.Metadata = map[string]any{} + for key, value := range input.Metadata { + copied.Metadata[key] = value + } + } + if input.Channel != nil { + channel := *input.Channel + if input.Channel.Metadata != nil { + channel.Metadata = map[string]any{} + for key, value := range input.Channel.Metadata { + channel.Metadata[key] = value + } + } + copied.Channel = &channel + } + s.boundAuthIdentity = &copied + + now := time.Now().UTC() + result := &service.AdminBoundAuthIdentity{ + UserID: userID, + ProviderType: input.ProviderType, + ProviderKey: input.ProviderKey, + ProviderSubject: input.ProviderSubject, + VerifiedAt: &now, + Issuer: input.Issuer, + Metadata: input.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + if input.Channel != nil { + result.Channel = &service.AdminBoundAuthIdentityChannel{ + Channel: input.Channel.Channel, + ChannelAppID: input.Channel.ChannelAppID, + ChannelSubject: input.Channel.ChannelSubject, + Metadata: input.Channel.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + } + return result, nil +} + func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index ee3fbb1e..321214af 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -66,6 +66,22 @@ type UpdateBalanceRequest struct { Notes string `json:"notes"` } +type BindUserAuthIdentityRequest struct { + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + Issuer *string `json:"issuer"` + Metadata map[string]any `json:"metadata"` + Channel *BindUserAuthIdentityChannelRequest `json:"channel"` +} + +type BindUserAuthIdentityChannelRequest struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` +} + // List handles listing all users with pagination // GET /api/v1/admin/users // Query params: @@ -197,6 +213,45 @@ func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) { response.Paginated(c, reports, total, page, pageSize) } +// BindAuthIdentity manually binds a canonical auth identity to a user. +// POST /api/v1/admin/users/:id/auth-identities +func (h *UserHandler) BindAuthIdentity(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req BindUserAuthIdentityRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + input := service.AdminBindAuthIdentityInput{ + ProviderType: req.ProviderType, + ProviderKey: req.ProviderKey, + ProviderSubject: req.ProviderSubject, + Issuer: req.Issuer, + Metadata: req.Metadata, + } + if req.Channel != nil { + input.Channel = &service.AdminBindAuthIdentityChannelInput{ + Channel: req.Channel.Channel, + ChannelAppID: req.Channel.ChannelAppID, + ChannelSubject: req.Channel.ChannelSubject, + Metadata: req.Channel.Metadata, + } + } + + result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + // Create handles creating a new user // POST /api/v1/admin/users func (h *UserHandler) Create(c *gin.Context) { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 0b5aaf09..c78fba33 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -214,6 +214,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports) users.GET("", h.Admin.User.List) users.GET("/:id", h.Admin.User.GetByID) + users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity) users.POST("", h.Admin.User.Create) users.PUT("/:id", h.Admin.User.Update) users.DELETE("/:id", h.Admin.User.Delete) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 972681a5..9ff26861 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -13,6 +13,8 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -39,6 +41,7 @@ type AdminService interface { GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) + BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) @@ -146,6 +149,44 @@ type AuthIdentityMigrationReportSummary struct { ByType map[string]int64 `json:"by_type"` } +type AdminBindAuthIdentityInput struct { + ProviderType string + ProviderKey string + ProviderSubject string + Issuer *string + Metadata map[string]any + Channel *AdminBindAuthIdentityChannelInput +} + +type AdminBindAuthIdentityChannelInput struct { + Channel string + ChannelAppID string + ChannelSubject string + Metadata map[string]any +} + +type AdminBoundAuthIdentity struct { + UserID int64 `json:"user_id"` + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + Issuer *string `json:"issuer,omitempty"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"` +} + +type AdminBoundAuthIdentityChannel struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type CreateGroupInput struct { Name string Description string @@ -895,6 +936,143 @@ ORDER BY report_type ASC`) return summary, nil } +func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0") + } + if s == nil || s.entClient == nil || s.userRepo == nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable") + } + if _, err := s.userRepo.GetByID(ctx, userID); err != nil { + return nil, err + } + + providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType) + providerKey := strings.TrimSpace(input.ProviderKey) + providerSubject := strings.TrimSpace(input.ProviderSubject) + if providerType == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat") + } + if providerKey == "" || providerSubject == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") + } + + var issuer *string + if input.Issuer != nil { + trimmed := strings.TrimSpace(*input.Issuer) + if trimmed != "" { + issuer = &trimmed + } + } + + channelInput := normalizeAdminBindChannelInput(input.Channel) + if input.Channel != nil && channelInput == nil { + return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided") + } + + verifiedAt := time.Now().UTC() + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + identity, err := tx.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(providerSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if identity != nil && identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + + if identity == nil { + create := tx.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetVerifiedAt(verifiedAt) + if issuer != nil { + create = create.SetIssuer(*issuer) + } + if input.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } else { + update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt) + if issuer != nil { + update = update.SetIssuer(*issuer) + } + if input.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } + + var channel *dbent.AuthIdentityChannel + if channelInput != nil { + channel, err = tx.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ChannelEQ(channelInput.Channel), + authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), + authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), + ). + WithIdentity(). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + if channel == nil { + create := tx.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetChannel(channelInput.Channel). + SetChannelAppID(channelInput.ChannelAppID). + SetChannelSubject(channelInput.ChannelSubject) + if channelInput.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } else { + update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID) + if channelInput.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } + } + + if err := tx.Commit(); err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err) + } + return buildAdminBoundAuthIdentity(identity, channel), nil +} + func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { if s == nil || s.entClient == nil { return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") @@ -906,6 +1084,90 @@ func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { return driver.DB(), nil } +func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput { + if input == nil { + return nil + } + channel := &AdminBindAuthIdentityChannelInput{ + Channel: strings.TrimSpace(input.Channel), + ChannelAppID: strings.TrimSpace(input.ChannelAppID), + ChannelSubject: strings.TrimSpace(input.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(input.Metadata), + } + if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" { + return nil + } + return channel +} + +func normalizeAdminAuthIdentityProviderType(input string) string { + switch strings.ToLower(strings.TrimSpace(input)) { + case "email": + return "email" + case "linuxdo": + return "linuxdo" + case "oidc": + return "oidc" + case "wechat": + return "wechat" + default: + return "" + } +} + +func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity { + if identity == nil { + return nil + } + result := &AdminBoundAuthIdentity{ + UserID: identity.UserID, + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + } + if channel != nil { + result.Channel = &AdminBoundAuthIdentityChannel{ + Channel: strings.TrimSpace(channel.Channel), + ChannelAppID: strings.TrimSpace(channel.ChannelAppID), + ChannelSubject: strings.TrimSpace(channel.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata), + CreatedAt: channel.CreatedAt, + UpdatedAt: channel.UpdatedAt, + } + } + return result +} + +func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any { + if input == nil { + return nil + } + if len(input) == 0 { + return map[string]any{} + } + data, err := json.Marshal(input) + if err != nil { + out := make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + return out + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + out = make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + } + return out +} + func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { var ( report AuthIdentityMigrationReport diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go new file mode 100644 index 00000000..f8ce3935 --- /dev/null +++ b/backend/internal/service/admin_service_auth_identity_binding_test.go @@ -0,0 +1,215 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/enttest" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("bind-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-123", + Metadata: map[string]any{"source": "admin-repair"}, + Channel: &AdminBindAuthIdentityChannelInput{ + Channel: "open", + ChannelAppID: "wx-open", + ChannelSubject: "openid-123", + Metadata: map[string]any{"scene": "migration"}, + }, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, user.ID, result.UserID) + require.Equal(t, "wechat", result.ProviderType) + require.Equal(t, "wechat-main", result.ProviderKey) + require.NotNil(t, result.VerifiedAt) + require.NotNil(t, result.Channel) + require.Equal(t, "open", result.Channel.Channel) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.Equal(t, "admin-repair", identity.Metadata["source"]) + require.NotNil(t, identity.VerifiedAt) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, identity.ID, channel.IdentityID) + require.Equal(t, "migration", channel.Metadata["scene"]) +} + +func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + target, err := client.User.Create(). + SetEmail("target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("subject-1"). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }) + require.Error(t, err) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err)) +} + +func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("same-user@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "first"}, + }) + require.NoError(t, err) + + second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "second"}, + }) + require.NoError(t, err) + require.Equal(t, first.UserID, second.UserID) + require.Equal(t, "second", second.Metadata["source"]) + + identities, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("subject-2"), + ). + All(ctx) + require.NoError(t, err) + require.Len(t, identities, 1) + require.Equal(t, "second", identities[0].Metadata["source"]) +} + +func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("invalid-provider@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "github", + ProviderKey: "github-main", + ProviderSubject: "subject-3", + }) + require.Error(t, err) + require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err)) +} From 724f8e89a1d8a61b9847883dea8573894b6b5c02 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 22:29:21 +0800 Subject: [PATCH 050/326] feat: resolve auth identity migration reports --- .../admin/admin_basic_handlers_test.go | 14 ++- .../handler/admin/admin_service_stub_test.go | 14 +++ .../internal/handler/admin/user_handler.go | 39 ++++++ backend/internal/server/routes/admin.go | 1 + backend/internal/service/admin_service.go | 111 ++++++++++++++++-- ..._service_identity_migration_report_test.go | 34 +++++- ...h_identity_migration_report_resolution.sql | 11 ++ 7 files changed, 209 insertions(+), 15 deletions(-) create mode 100644 backend/migrations/114_auth_identity_migration_report_resolution.sql diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 57620005..207931d9 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "testing" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -24,6 +25,10 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.GET("/api/v1/admin/users", userHandler.List) router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary) router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports) + router.POST("/api/v1/admin/users/auth-identity-migration-reports/:id/resolve", func(c *gin.Context) { + c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 99}) + userHandler.ResolveAuthIdentityMigrationReport(c) + }) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity) router.POST("/api/v1/admin/users", userHandler.Create) @@ -83,6 +88,13 @@ func TestUserHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + body, _ := json.Marshal(map[string]any{"resolution_note": "resolved by manual bind"}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/auth-identity-migration-reports/1/resolve", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil) router.ServeHTTP(rec, req) @@ -99,7 +111,7 @@ func TestUserHandlerEndpoints(t *testing.T) { "channel_subject": "openid-123", }, } - body, _ := json.Marshal(bindBody) + body, _ = json.Marshal(bindBody) rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index c8c7a247..8ecadfdf 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -249,6 +249,20 @@ func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int6 return result, nil } +func (s *stubAdminService) ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*service.AuthIdentityMigrationReport, error) { + now := time.Now().UTC() + for i := range s.migrationReports { + if s.migrationReports[i].ID != reportID { + continue + } + s.migrationReports[i].ResolvedAt = &now + s.migrationReports[i].ResolvedByUserID = &resolvedByUserID + s.migrationReports[i].ResolutionNote = resolutionNote + return &s.migrationReports[i], nil + } + return nil, nil +} + func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 321214af..e582e322 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -7,6 +7,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -82,6 +83,10 @@ type BindUserAuthIdentityChannelRequest struct { Metadata map[string]any `json:"metadata"` } +type ResolveAuthIdentityMigrationReportRequest struct { + ResolutionNote string `json:"resolution_note"` +} + // List handles listing all users with pagination // GET /api/v1/admin/users // Query params: @@ -252,6 +257,40 @@ func (h *UserHandler) BindAuthIdentity(c *gin.Context) { response.Success(c, result) } +// ResolveAuthIdentityMigrationReport marks a migration report as resolved. +// POST /api/v1/admin/users/auth-identity-migration-reports/:id/resolve +func (h *UserHandler) ResolveAuthIdentityMigrationReport(c *gin.Context) { + reportID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid report ID") + return + } + + subject, ok := servermiddleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Authentication required") + return + } + + var req ResolveAuthIdentityMigrationReportRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + report, err := h.adminService.ResolveAuthIdentityMigrationReport( + c.Request.Context(), + reportID, + subject.UserID, + req.ResolutionNote, + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, report) +} + // Create handles creating a new user // POST /api/v1/admin/users func (h *UserHandler) Create(c *gin.Context) { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index c78fba33..e5c0eac1 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -212,6 +212,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { users.GET("/auth-identity-migration-reports/summary", h.Admin.User.GetAuthIdentityMigrationReportSummary) users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports) + users.POST("/auth-identity-migration-reports/:id/resolve", h.Admin.User.ResolveAuthIdentityMigrationReport) users.GET("", h.Admin.User.List) users.GET("/:id", h.Admin.User.GetByID) users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 9ff26861..3490374e 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -42,6 +42,7 @@ type AdminService interface { ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) + ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*AuthIdentityMigrationReport, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) @@ -137,16 +138,21 @@ type UpdateUserInput struct { } type AuthIdentityMigrationReport struct { - ID int64 `json:"id"` - ReportType string `json:"report_type"` - ReportKey string `json:"report_key"` - Details map[string]any `json:"details"` - CreatedAt time.Time `json:"created_at"` + ID int64 `json:"id"` + ReportType string `json:"report_type"` + ReportKey string `json:"report_key"` + Details map[string]any `json:"details"` + CreatedAt time.Time `json:"created_at"` + ResolvedAt *time.Time `json:"resolved_at,omitempty"` + ResolvedByUserID *int64 `json:"resolved_by_user_id,omitempty"` + ResolutionNote string `json:"resolution_note,omitempty"` } type AuthIdentityMigrationReportSummary struct { - Total int64 `json:"total"` - ByType map[string]int64 `json:"by_type"` + Total int64 `json:"total"` + OpenTotal int64 `json:"open_total"` + ResolvedTotal int64 `json:"resolved_total"` + ByType map[string]int64 `json:"by_type"` } type AdminBindAuthIdentityInput struct { @@ -874,7 +880,7 @@ WHERE ($1 = '' OR report_type = $1)`, } rows, err := db.QueryContext(ctx, ` -SELECT id, report_type, report_key, details, created_at +SELECT id, report_type, report_key, details, created_at, resolved_at, resolved_by_user_id, resolution_note FROM auth_identity_migration_reports WHERE ($1 = '' OR report_type = $1) ORDER BY created_at DESC, id DESC @@ -909,7 +915,11 @@ func (s *adminServiceImpl) GetAuthIdentityMigrationReportSummary(ctx context.Con } rows, err := db.QueryContext(ctx, ` -SELECT report_type, COUNT(*) +SELECT + report_type, + COUNT(*), + SUM(CASE WHEN resolved_at IS NULL THEN 1 ELSE 0 END), + SUM(CASE WHEN resolved_at IS NOT NULL THEN 1 ELSE 0 END) FROM auth_identity_migration_reports GROUP BY report_type ORDER BY report_type ASC`) @@ -924,11 +934,15 @@ ORDER BY report_type ASC`) for rows.Next() { var reportType string var count int64 - if err := rows.Scan(&reportType, &count); err != nil { + var openCount int64 + var resolvedCount int64 + if err := rows.Scan(&reportType, &count, &openCount, &resolvedCount); err != nil { return nil, err } summary.ByType[reportType] = count summary.Total += count + summary.OpenTotal += openCount + summary.ResolvedTotal += resolvedCount } if err := rows.Err(); err != nil { return nil, err @@ -936,6 +950,56 @@ ORDER BY report_type ASC`) return summary, nil } +func (s *adminServiceImpl) ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*AuthIdentityMigrationReport, error) { + if reportID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "report id must be greater than 0") + } + if resolvedByUserID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "resolved_by_user_id must be greater than 0") + } + + db, err := s.adminSQLDB() + if err != nil { + return nil, err + } + + now := time.Now().UTC() + result, err := db.ExecContext(ctx, ` +UPDATE auth_identity_migration_reports +SET + resolved_at = COALESCE(resolved_at, $2), + resolved_by_user_id = COALESCE(resolved_by_user_id, $3), + resolution_note = $4 +WHERE id = $1`, + reportID, + now, + resolvedByUserID, + strings.TrimSpace(resolutionNote), + ) + if err != nil { + return nil, err + } + affected, err := result.RowsAffected() + if err != nil { + return nil, err + } + if affected == 0 { + return nil, infraerrors.NotFound("AUTH_IDENTITY_MIGRATION_REPORT_NOT_FOUND", "auth identity migration report not found") + } + + row := db.QueryRowContext(ctx, ` +SELECT id, report_type, report_key, details, created_at, resolved_at, resolved_by_user_id, resolution_note +FROM auth_identity_migration_reports +WHERE id = $1`, + reportID, + ) + report, err := scanAuthIdentityMigrationReport(row) + if err != nil { + return nil, err + } + return &report, nil +} + func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { if userID <= 0 { return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0") @@ -1170,10 +1234,22 @@ func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any { func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { var ( - report AuthIdentityMigrationReport - details []byte + report AuthIdentityMigrationReport + details []byte + resolvedAt sql.NullTime + resolvedByUserID sql.NullInt64 + resolutionNote sql.NullString ) - if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil { + if err := scanner.Scan( + &report.ID, + &report.ReportType, + &report.ReportKey, + &details, + &report.CreatedAt, + &resolvedAt, + &resolvedByUserID, + &resolutionNote, + ); err != nil { return AuthIdentityMigrationReport{}, err } report.Details = map[string]any{} @@ -1182,6 +1258,15 @@ func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error return AuthIdentityMigrationReport{}, err } } + if resolvedAt.Valid { + report.ResolvedAt = &resolvedAt.Time + } + if resolvedByUserID.Valid { + report.ResolvedByUserID = &resolvedByUserID.Int64 + } + if resolutionNote.Valid { + report.ResolutionNote = resolutionNote.String + } return report, nil } diff --git a/backend/internal/service/admin_service_identity_migration_report_test.go b/backend/internal/service/admin_service_identity_migration_report_test.go index 75ca3e5a..6975604b 100644 --- a/backend/internal/service/admin_service_identity_migration_report_test.go +++ b/backend/internal/service/admin_service_identity_migration_report_test.go @@ -30,7 +30,10 @@ func newAdminServiceMigrationReportTestClient(t *testing.T) *dbent.Client { report_type TEXT NOT NULL, report_key TEXT NOT NULL, details TEXT NOT NULL DEFAULT '{}', - created_at DATETIME NOT NULL + created_at DATETIME NOT NULL, + resolved_at DATETIME NULL, + resolved_by_user_id INTEGER NULL, + resolution_note TEXT NOT NULL DEFAULT '' )`) require.NoError(t, err) @@ -87,6 +90,35 @@ VALUES summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background()) require.NoError(t, err) require.Equal(t, int64(3), summary.Total) + require.Equal(t, int64(3), summary.OpenTotal) + require.Zero(t, summary.ResolvedTotal) require.Equal(t, int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"]) require.Equal(t, int64(2), summary.ByType["wechat_provider_key_conflict"]) } + +func TestAdminServiceResolveAuthIdentityMigrationReport(t *testing.T) { + client := newAdminServiceMigrationReportTestClient(t) + driver, ok := client.Driver().(*entsql.Driver) + require.True(t, ok) + + now := time.Now().UTC() + _, err := driver.DB().ExecContext(context.Background(), ` +INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at) +VALUES ($1, $2, $3, $4)`, + "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now, + ) + require.NoError(t, err) + + svc := &adminServiceImpl{entClient: client} + report, err := svc.ResolveAuthIdentityMigrationReport(context.Background(), 1, 99, "resolved by admin binding") + require.NoError(t, err) + require.NotNil(t, report.ResolvedAt) + require.NotNil(t, report.ResolvedByUserID) + require.Equal(t, int64(99), *report.ResolvedByUserID) + require.Equal(t, "resolved by admin binding", report.ResolutionNote) + + summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background()) + require.NoError(t, err) + require.Zero(t, summary.OpenTotal) + require.Equal(t, int64(1), summary.ResolvedTotal) +} diff --git a/backend/migrations/114_auth_identity_migration_report_resolution.sql b/backend/migrations/114_auth_identity_migration_report_resolution.sql new file mode 100644 index 00000000..f84bf822 --- /dev/null +++ b/backend/migrations/114_auth_identity_migration_report_resolution.sql @@ -0,0 +1,11 @@ +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolved_at TIMESTAMPTZ NULL; + +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolved_by_user_id BIGINT NULL; + +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolution_note TEXT NOT NULL DEFAULT ''; + +CREATE INDEX IF NOT EXISTS idx_auth_identity_migration_reports_resolved_at + ON auth_identity_migration_reports (resolved_at); From bffcc2042eeeb61eb0e5c9fc0c7cd4d40ff89f96 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 22:37:25 +0800 Subject: [PATCH 051/326] fix: complete oidc pending auth callback flows --- frontend/src/views/auth/OidcCallbackView.vue | 440 +++++++++++++----- .../auth/__tests__/OidcCallbackView.spec.ts | 243 ++++++---- 2 files changed, 495 insertions(+), 188 deletions(-) diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index deb40b20..de3f2e40 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -18,9 +18,10 @@
@@ -108,39 +109,6 @@ - - - - + + + + + +
@@ -177,11 +263,12 @@ diff --git a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts new file mode 100644 index 00000000..cfbd9f1c --- /dev/null +++ b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts @@ -0,0 +1,80 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import WechatPaymentCallbackView from '@/views/auth/WechatPaymentCallbackView.vue' + +const { replaceMock, routeState, locationState } = vi.hoisted(() => ({ + replaceMock: vi.fn(), + routeState: { + query: {} as Record, + }, + locationState: { + current: { + href: 'http://localhost/auth/wechat/payment/callback', + hash: '', + search: '', + pathname: '/auth/wechat/payment/callback', + origin: 'http://localhost', + } as Location & { origin: string }, + }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, + useRouter: () => ({ + replace: replaceMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key, + locale: { value: 'zh-CN' }, + }), +})) + +describe('WechatPaymentCallbackView', () => { + beforeEach(() => { + replaceMock.mockReset() + routeState.query = {} + locationState.current = { + href: 'http://localhost/auth/wechat/payment/callback', + hash: '', + search: '', + pathname: '/auth/wechat/payment/callback', + origin: 'http://localhost', + } as Location & { origin: string } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + }) + + it('redirects back to purchase with openid and payment context from hash fragment', async () => { + locationState.current.hash = '#openid=openid-123&payment_type=wxpay&amount=12.5&order_type=balance&redirect=%2Fpurchase%3Ffrom%3Dwechat' + + mount(WechatPaymentCallbackView) + await flushPromises() + + expect(replaceMock).toHaveBeenCalledWith({ + path: '/purchase', + query: { + from: 'wechat', + wechat_resume: '1', + openid: 'openid-123', + payment_type: 'wxpay', + amount: '12.5', + order_type: 'balance', + }, + }) + }) + + it('shows an error when the callback payload is missing openid', async () => { + locationState.current.hash = '#payment_type=wxpay' + + const wrapper = mount(WechatPaymentCallbackView) + await flushPromises() + + expect(replaceMock).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('微信支付回调缺少 openid。') + }) +}) diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index f973ad5b..bfb9dae2 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -309,6 +309,20 @@ const previewImage = ref('') const paymentPhase = ref<'select' | 'paying'>('select') +interface CreateOrderOptions { + openid?: string + paymentType?: string + isResume?: boolean +} + +interface WeixinJSBridgeLike { + invoke( + action: string, + payload: Record, + callback: (result: Record) => void, + ): void +} + function emptyPaymentState(): PaymentRecoverySnapshot { return { orderId: 0, @@ -326,6 +340,48 @@ function emptyPaymentState(): PaymentRecoverySnapshot { } } +function readRouteQueryValue(value: unknown): string { + if (Array.isArray(value)) { + return typeof value[0] === 'string' ? value[0] : '' + } + return typeof value === 'string' ? value : '' +} + +function getWeixinJSBridge(): WeixinJSBridgeLike | undefined { + return (window as Window & { WeixinJSBridge?: WeixinJSBridgeLike }).WeixinJSBridge +} + +function waitForWeixinJSBridge(timeoutMs = 4000): Promise { + const existing = getWeixinJSBridge() + if (existing) return Promise.resolve(existing) + + return new Promise((resolve) => { + let settled = false + const finish = (bridge: WeixinJSBridgeLike | null) => { + if (settled) return + settled = true + document.removeEventListener('WeixinJSBridgeReady', handleReady) + document.removeEventListener('onWeixinJSBridgeReady', handleReady) + window.clearTimeout(timer) + resolve(bridge) + } + const handleReady = () => finish(getWeixinJSBridge() ?? null) + const timer = window.setTimeout(() => finish(getWeixinJSBridge() ?? null), timeoutMs) + document.addEventListener('WeixinJSBridgeReady', handleReady, false) + document.addEventListener('onWeixinJSBridgeReady', handleReady, false) + }) +} + +async function invokeWechatJsapiPayment(payload: Record): Promise> { + const bridge = await waitForWeixinJSBridge() + if (!bridge) { + throw new Error('WeixinJSBridge is unavailable') + } + return new Promise((resolve) => { + bridge.invoke('getBrandWCPayRequest', payload, (result) => resolve(result || {})) + }) +} + const paymentState = ref(emptyPaymentState()) function persistRecoverySnapshot(snapshot: PaymentRecoverySnapshot) { @@ -560,25 +616,32 @@ async function confirmSubscribe() { await createOrder(selectedPlan.value.price, 'subscription', selectedPlan.value.id) } -async function createOrder(orderAmount: number, orderType: OrderType, planId?: number) { +async function createOrder(orderAmount: number, orderType: OrderType, planId?: number, options: CreateOrderOptions = {}) { submitting.value = true errorMessage.value = '' try { - const result = await paymentStore.createOrder(buildCreateOrderPayload({ + const requestType = normalizeVisibleMethod(options.paymentType || selectedMethod.value) || options.paymentType || selectedMethod.value + const payload = buildCreateOrderPayload({ amount: orderAmount, - paymentType: selectedMethod.value, + paymentType: requestType, orderType, planId, origin: typeof window !== 'undefined' ? window.location.origin : '', isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent), - })) as CreateOrderResult & { resume_token?: string } + }) + if (options.openid) { + payload.openid = options.openid + } + payload.is_mobile = isMobileDevice() + + const result = await paymentStore.createOrder(payload) as CreateOrderResult & { resume_token?: string } const openWindow = (url: string, features = POPUP_WINDOW_FEATURES) => { const win = window.open(url, 'paymentPopup', features) if (!win || win.closed) { window.location.href = url } } - const visibleMethod = normalizeVisibleMethod(selectedMethod.value) || selectedMethod.value + const visibleMethod = normalizeVisibleMethod(requestType) || requestType const stripeMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay' const stripeRouteUrl = result.client_secret ? router.resolve({ @@ -599,6 +662,11 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n stripeRouteUrl, }) + if (decision.kind === 'wechat_oauth' && decision.oauth?.authorize_url) { + window.location.href = decision.oauth.authorize_url + return + } + if (decision.kind === 'unhandled') { errorMessage.value = t('payment.result.failed') appStore.showError(errorMessage.value) @@ -617,6 +685,16 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n window.location.href = decision.paymentState.payUrl return } + if (decision.kind === 'wechat_jsapi' && decision.jsapi) { + const jsapiResult = await invokeWechatJsapiPayment(decision.jsapi as Record) + const errMsg = String(jsapiResult.err_msg || '').toLowerCase() + if (errMsg.includes('cancel')) { + appStore.showInfo(t('payment.qr.cancelled')) + } else if (errMsg && !errMsg.includes('ok')) { + appStore.showError(t('payment.result.failed')) + } + return + } if (decision.kind === 'redirect_waiting' && decision.paymentState.payUrl) { if (isMobileDevice()) { window.location.href = decision.paymentState.payUrl @@ -640,6 +718,50 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } } +async function resumeWechatPaymentFromQuery() { + const openid = readRouteQueryValue(route.query.openid) + if (readRouteQueryValue(route.query.wechat_resume) !== '1' || !openid) { + return + } + + const paymentType = normalizeVisibleMethod(readRouteQueryValue(route.query.payment_type)) || 'wxpay' + const orderType = readRouteQueryValue(route.query.order_type) === 'subscription' ? 'subscription' : 'balance' + const planId = Number.parseInt(readRouteQueryValue(route.query.plan_id), 10) + const rawAmount = Number.parseFloat(readRouteQueryValue(route.query.amount)) + const orderAmount = Number.isFinite(rawAmount) && rawAmount > 0 + ? rawAmount + : (orderType === 'subscription' + ? (checkout.value.plans.find(plan => plan.id === planId)?.price ?? 0) + : validAmount.value) + + selectedMethod.value = paymentType + if (orderType === 'balance' && orderAmount > 0) { + amount.value = orderAmount + } + if (orderType === 'subscription' && Number.isFinite(planId) && planId > 0) { + selectedPlan.value = checkout.value.plans.find(plan => plan.id === planId) ?? null + } + + const nextQuery = { ...route.query } + delete nextQuery.wechat_resume + delete nextQuery.openid + delete nextQuery.state + delete nextQuery.scope + delete nextQuery.payment_type + delete nextQuery.amount + delete nextQuery.order_type + delete nextQuery.plan_id + await router.replace({ path: route.path, query: nextQuery }) + + if (orderAmount > 0) { + await createOrder(orderAmount, orderType, Number.isFinite(planId) && planId > 0 ? planId : undefined, { + openid, + paymentType, + isResume: true, + }) + } +} + onMounted(async () => { try { const res = await paymentAPI.getCheckoutInfo() @@ -672,6 +794,7 @@ onMounted(async () => { removeRecoverySnapshot() } } + await resumeWechatPaymentFromQuery() if (checkout.value.balance_disabled) { activeTab.value = 'subscription' } From a1425b457d91f59c526a2f56cc4f7c78b8c3c3c6 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 20 Apr 2026 23:38:59 +0800 Subject: [PATCH 057/326] feat(channel-monitor): redesign user dashboard as card grid MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reference check-cx UI: INTELLIGENCE MONITOR hero + 3-column card grid with 60-point timeline bars. Backend: - Add PrimaryPingLatencyMs + Timeline[60] to UserMonitorView - ListRecentHistoryForMonitors: batch CTE + ROW_NUMBER() window query - indexLatestByModel / indexAvailabilityByModel helpers Frontend: - 7 new components: ProviderIcon, MonitorMetricPair, MonitorAvailabilityRow, MonitorTimeline, MonitorHero, MonitorCard, MonitorCardGrid - ChannelStatusView 381→~180 lines (delegated to subcomponents) - AbortController reload concurrency protection - HSL 0-120° availability color mapping - Replace emoji with Icon component (bolt / globe) - i18n: monitorCommon.* shared namespace, channelStatus.hero.* Bump VERSION to 0.1.114.24 --- .../handler/channel_monitor_user_handler.go | 60 +++-- .../repository/channel_monitor_repo.go | 135 +++++++++- .../service/channel_monitor_aggregator.go | 95 ++++++- .../internal/service/channel_monitor_const.go | 3 + .../service/channel_monitor_service.go | 3 + .../internal/service/channel_monitor_types.go | 37 ++- frontend/src/api/channelMonitor.ts | 9 + .../user/MonitorPrimaryModelCell.vue | 71 ----- .../user/monitor/MonitorAvailabilityRow.vue | 49 ++++ .../components/user/monitor/MonitorCard.vue | 128 +++++++++ .../user/monitor/MonitorCardGrid.vue | 81 ++++++ .../components/user/monitor/MonitorHero.vue | 133 +++++++++ .../user/monitor/MonitorMetricPair.vue | 45 ++++ .../user/monitor/MonitorTimeline.vue | 113 ++++++++ .../components/user/monitor/ProviderIcon.vue | 71 +++++ .../composables/useChannelMonitorFormat.ts | 60 ++++- frontend/src/i18n/locales/en.ts | 33 ++- frontend/src/i18n/locales/zh.ts | 33 ++- frontend/src/views/user/ChannelStatusView.vue | 253 ++++++++---------- 19 files changed, 1134 insertions(+), 278 deletions(-) delete mode 100644 frontend/src/components/user/MonitorPrimaryModelCell.vue create mode 100644 frontend/src/components/user/monitor/MonitorAvailabilityRow.vue create mode 100644 frontend/src/components/user/monitor/MonitorCard.vue create mode 100644 frontend/src/components/user/monitor/MonitorCardGrid.vue create mode 100644 frontend/src/components/user/monitor/MonitorHero.vue create mode 100644 frontend/src/components/user/monitor/MonitorMetricPair.vue create mode 100644 frontend/src/components/user/monitor/MonitorTimeline.vue create mode 100644 frontend/src/components/user/monitor/ProviderIcon.vue diff --git a/backend/internal/handler/channel_monitor_user_handler.go b/backend/internal/handler/channel_monitor_user_handler.go index a031b4a2..6a513dc1 100644 --- a/backend/internal/handler/channel_monitor_user_handler.go +++ b/backend/internal/handler/channel_monitor_user_handler.go @@ -1,6 +1,8 @@ package handler import ( + "time" + "github.com/Wei-Shaw/sub2api/internal/handler/admin" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -22,15 +24,26 @@ func NewChannelMonitorUserHandler(monitorService *service.ChannelMonitorService) // --- Response --- type channelMonitorUserListItem struct { - ID int64 `json:"id"` - Name string `json:"name"` - Provider string `json:"provider"` - GroupName string `json:"group_name"` - PrimaryModel string `json:"primary_model"` - PrimaryStatus string `json:"primary_status"` - PrimaryLatencyMs *int `json:"primary_latency_ms"` - Availability7d float64 `json:"availability_7d"` - ExtraModels []dto.ChannelMonitorExtraModelStatus `json:"extra_models"` + ID int64 `json:"id"` + Name string `json:"name"` + Provider string `json:"provider"` + GroupName string `json:"group_name"` + PrimaryModel string `json:"primary_model"` + PrimaryStatus string `json:"primary_status"` + PrimaryLatencyMs *int `json:"primary_latency_ms"` + PrimaryPingLatencyMs *int `json:"primary_ping_latency_ms"` + Availability7d float64 `json:"availability_7d"` + ExtraModels []dto.ChannelMonitorExtraModelStatus `json:"extra_models"` + Timeline []channelMonitorUserTimelinePoint `json:"timeline"` +} + +// channelMonitorUserTimelinePoint 主模型最近一次检测的 timeline 点。 +// 仅用于用户视图 list 响应,admin 视图不使用。 +type channelMonitorUserTimelinePoint struct { + Status string `json:"status"` + LatencyMs *int `json:"latency_ms"` + PingLatencyMs *int `json:"ping_latency_ms"` + CheckedAt string `json:"checked_at"` } type channelMonitorUserDetailResponse struct { @@ -60,16 +73,27 @@ func userMonitorViewToItem(v *service.UserMonitorView) channelMonitorUserListIte LatencyMs: e.LatencyMs, }) } + timeline := make([]channelMonitorUserTimelinePoint, 0, len(v.Timeline)) + for _, p := range v.Timeline { + timeline = append(timeline, channelMonitorUserTimelinePoint{ + Status: p.Status, + LatencyMs: p.LatencyMs, + PingLatencyMs: p.PingLatencyMs, + CheckedAt: p.CheckedAt.UTC().Format(time.RFC3339), + }) + } return channelMonitorUserListItem{ - ID: v.ID, - Name: v.Name, - Provider: v.Provider, - GroupName: v.GroupName, - PrimaryModel: v.PrimaryModel, - PrimaryStatus: v.PrimaryStatus, - PrimaryLatencyMs: v.PrimaryLatencyMs, - Availability7d: v.Availability7d, - ExtraModels: extras, + ID: v.ID, + Name: v.Name, + Provider: v.Provider, + GroupName: v.GroupName, + PrimaryModel: v.PrimaryModel, + PrimaryStatus: v.PrimaryStatus, + PrimaryLatencyMs: v.PrimaryLatencyMs, + PrimaryPingLatencyMs: v.PrimaryPingLatencyMs, + Availability7d: v.Availability7d, + ExtraModels: extras, + Timeline: timeline, } } diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go index b943f33c..cf5e1a93 100644 --- a/backend/internal/repository/channel_monitor_repo.go +++ b/backend/internal/repository/channel_monitor_repo.go @@ -243,7 +243,7 @@ func (r *channelMonitorRepository) ListHistory(ctx context.Context, monitorID in func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monitorID int64) ([]*service.ChannelMonitorLatest, error) { const q = ` SELECT DISTINCT ON (model) - model, status, latency_ms, checked_at + model, status, latency_ms, ping_latency_ms, checked_at FROM channel_monitor_histories WHERE monitor_id = $1 ORDER BY model, checked_at DESC @@ -257,19 +257,27 @@ func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monit out := make([]*service.ChannelMonitorLatest, 0) for rows.Next() { l := &service.ChannelMonitorLatest{} - var latency sql.NullInt64 - if err := rows.Scan(&l.Model, &l.Status, &latency, &l.CheckedAt); err != nil { + var latency, ping sql.NullInt64 + if err := rows.Scan(&l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil { return nil, fmt.Errorf("scan latest row: %w", err) } - if latency.Valid { - v := int(latency.Int64) - l.LatencyMs = &v - } + assignNullInt(&l.LatencyMs, latency) + assignNullInt(&l.PingLatencyMs, ping) out = append(out, l) } return out, rows.Err() } +// assignNullInt 把 sql.NullInt64 解包到 *int 指针目标(valid 才分配新 int)。 +// 集中实现避免 latency / ping 两处重复 if latency.Valid { v := int(...) ... } 模板。 +func assignNullInt(dst **int, n sql.NullInt64) { + if !n.Valid { + return + } + v := int(n.Int64) + *dst = &v +} + // ComputeAvailability 计算指定窗口内每个模型的可用率与平均延迟。 // "可用" = status IN (operational, degraded)。 func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*service.ChannelMonitorAvailability, error) { @@ -338,7 +346,7 @@ func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context, } const q = ` SELECT DISTINCT ON (monitor_id, model) - monitor_id, model, status, latency_ms, checked_at + monitor_id, model, status, latency_ms, ping_latency_ms, checked_at FROM channel_monitor_histories WHERE monitor_id = ANY($1) ORDER BY monitor_id, model, checked_at DESC @@ -352,14 +360,12 @@ func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context, for rows.Next() { var monitorID int64 l := &service.ChannelMonitorLatest{} - var latency sql.NullInt64 - if err := rows.Scan(&monitorID, &l.Model, &l.Status, &latency, &l.CheckedAt); err != nil { + var latency, ping sql.NullInt64 + if err := rows.Scan(&monitorID, &l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil { return nil, fmt.Errorf("scan latest batch row: %w", err) } - if latency.Valid { - v := int(latency.Int64) - l.LatencyMs = &v - } + assignNullInt(&l.LatencyMs, latency) + assignNullInt(&l.PingLatencyMs, ping) out[monitorID] = append(out[monitorID], l) } if err := rows.Err(); err != nil { @@ -368,6 +374,107 @@ func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context, return out, nil } +// ListRecentHistoryForMonitors 为多个 monitor 批量取各自"指定模型"最近 N 条历史(按 checked_at DESC,最新在前)。 +// primaryModels[monitorID] 指定该监控要过滤的模型名;monitor 不在 primaryModels 中的记录不返回。 +// 通过 CTE + unnest(两个 int8/text 数组) 构造 (monitor_id, model) 白名单, +// 再用 ROW_NUMBER() OVER (PARTITION BY monitor_id) 取各自前 N 条。 +// +// 返回值:map[monitorID] -> []*ChannelMonitorHistoryEntry(不含 message,减少网络开销)。 +// 空 ids / 空 primaryModels 返回空 map,不报错。 +func (r *channelMonitorRepository) ListRecentHistoryForMonitors( + ctx context.Context, + ids []int64, + primaryModels map[int64]string, + perMonitorLimit int, +) (map[int64][]*service.ChannelMonitorHistoryEntry, error) { + out := make(map[int64][]*service.ChannelMonitorHistoryEntry, len(ids)) + pairIDs, pairModels := buildMonitorModelPairs(ids, primaryModels) + if len(pairIDs) == 0 { + return out, nil + } + perMonitorLimit = clampTimelineLimit(perMonitorLimit) + + const q = ` + WITH targets AS ( + SELECT unnest($1::bigint[]) AS monitor_id, + unnest($2::text[]) AS model + ), + ranked AS ( + SELECT h.monitor_id, + h.status, + h.latency_ms, + h.ping_latency_ms, + h.checked_at, + ROW_NUMBER() OVER (PARTITION BY h.monitor_id ORDER BY h.checked_at DESC) AS rn + FROM channel_monitor_histories h + JOIN targets t + ON t.monitor_id = h.monitor_id AND t.model = h.model + ) + SELECT monitor_id, status, latency_ms, ping_latency_ms, checked_at + FROM ranked + WHERE rn <= $3 + ORDER BY monitor_id, checked_at DESC + ` + rows, err := r.db.QueryContext(ctx, q, pq.Array(pairIDs), pq.Array(pairModels), perMonitorLimit) + if err != nil { + return nil, fmt.Errorf("query recent history batch: %w", err) + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var monitorID int64 + entry := &service.ChannelMonitorHistoryEntry{} + var latency, ping sql.NullInt64 + if err := rows.Scan(&monitorID, &entry.Status, &latency, &ping, &entry.CheckedAt); err != nil { + return nil, fmt.Errorf("scan recent history row: %w", err) + } + assignNullInt(&entry.LatencyMs, latency) + assignNullInt(&entry.PingLatencyMs, ping) + out[monitorID] = append(out[monitorID], entry) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil +} + +// buildMonitorModelPairs 基于 ids 过滤出有效的 (monitor_id, model) 对,model 为空时跳过。 +// 保证两个数组长度一致且一一对应,供 unnest 展开。 +func buildMonitorModelPairs(ids []int64, primaryModels map[int64]string) ([]int64, []string) { + if len(ids) == 0 || len(primaryModels) == 0 { + return nil, nil + } + pairIDs := make([]int64, 0, len(ids)) + pairModels := make([]string, 0, len(ids)) + for _, id := range ids { + model, ok := primaryModels[id] + if !ok || strings.TrimSpace(model) == "" { + continue + } + pairIDs = append(pairIDs, id) + pairModels = append(pairModels, model) + } + return pairIDs, pairModels +} + +// timelineLimit* 批量 timeline 查询的 perMonitorLimit 夹紧范围。 +// 下限 1 表示至少返回最近一条;上限 200 控制单次响应体与 SQL 内存占用(ROW_NUMBER 窗口上限)。 +const ( + timelineLimitMin = 1 + timelineLimitMax = 200 +) + +// clampTimelineLimit 把 perMonitorLimit 夹紧到 [timelineLimitMin, timelineLimitMax],避免非法值或超大查询。 +func clampTimelineLimit(n int) int { + if n < timelineLimitMin { + return timelineLimitMin + } + if n > timelineLimitMax { + return timelineLimitMax + } + return n +} + // ComputeAvailabilityForMonitors 一次性计算多个监控在某个窗口内的每模型可用率与平均延迟。 func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*service.ChannelMonitorAvailability, error) { out := make(map[int64][]*service.ChannelMonitorAvailability, len(ids)) diff --git a/backend/internal/service/channel_monitor_aggregator.go b/backend/internal/service/channel_monitor_aggregator.go index 97015b40..09020f5f 100644 --- a/backend/internal/service/channel_monitor_aggregator.go +++ b/backend/internal/service/channel_monitor_aggregator.go @@ -49,7 +49,12 @@ func (s *ChannelMonitorService) BatchMonitorStatusSummary( } // ListUserView 用户只读视图:列出所有 enabled 监控的概览。 -// 使用批量聚合接口避免 N+1:1 次查 monitors,1 次查 latest(所有 monitor),1 次查 availability。 +// 使用批量聚合接口避免 N+1: +// +// 1 次查 monitors; +// 1 次批量 latest(含 ping_latency_ms); +// 1 次批量 7d availability; +// 1 次批量 timeline(主模型最近 N 条)。 func (s *ChannelMonitorService) ListUserView(ctx context.Context) ([]*UserMonitorView, error) { monitors, err := s.repo.ListEnabled(ctx) if err != nil { @@ -59,6 +64,21 @@ func (s *ChannelMonitorService) ListUserView(ctx context.Context) ([]*UserMonito return []*UserMonitorView{}, nil } + ids, primaryByID, extrasByID := collectMonitorIndexes(monitors) + summaries := s.BatchMonitorStatusSummary(ctx, ids, primaryByID, extrasByID) + latestMap := s.batchLatest(ctx, ids) + timelineMap := s.batchTimeline(ctx, ids, primaryByID) + + views := make([]*UserMonitorView, 0, len(monitors)) + for _, m := range monitors { + primaryLatest := pickLatest(latestMap[m.ID], m.PrimaryModel) + views = append(views, buildUserViewFromSummary(m, summaries[m.ID], primaryLatest, timelineMap[m.ID])) + } + return views, nil +} + +// collectMonitorIndexes 把 monitors 列表按 ID 展开为聚合查询所需的三个索引结构。 +func collectMonitorIndexes(monitors []*ChannelMonitor) ([]int64, map[int64]string, map[int64][]string) { ids := make([]int64, 0, len(monitors)) primaryByID := make(map[int64]string, len(monitors)) extrasByID := make(map[int64][]string, len(monitors)) @@ -67,14 +87,44 @@ func (s *ChannelMonitorService) ListUserView(ctx context.Context) ([]*UserMonito primaryByID[m.ID] = m.PrimaryModel extrasByID[m.ID] = m.ExtraModels } - summaries := s.BatchMonitorStatusSummary(ctx, ids, primaryByID, extrasByID) + return ids, primaryByID, extrasByID +} - views := make([]*UserMonitorView, 0, len(monitors)) - for _, m := range monitors { - summary := summaries[m.ID] - views = append(views, buildUserViewFromSummary(m, summary)) +// batchLatest 批量取 latest per model,失败仅日志(与现有 BatchMonitorStatusSummary 一致,不阻断列表渲染)。 +func (s *ChannelMonitorService) batchLatest(ctx context.Context, ids []int64) map[int64][]*ChannelMonitorLatest { + latestMap, err := s.repo.ListLatestForMonitorIDs(ctx, ids) + if err != nil { + slog.Warn("channel_monitor: user view batch latest failed", "error", err) + return map[int64][]*ChannelMonitorLatest{} } - return views, nil + return latestMap +} + +// batchTimeline 批量取每个 monitor 主模型最近 monitorTimelineMaxPoints 条历史。 +func (s *ChannelMonitorService) batchTimeline( + ctx context.Context, + ids []int64, + primaryByID map[int64]string, +) map[int64][]*ChannelMonitorHistoryEntry { + timelineMap, err := s.repo.ListRecentHistoryForMonitors(ctx, ids, primaryByID, monitorTimelineMaxPoints) + if err != nil { + slog.Warn("channel_monitor: user view batch timeline failed", "error", err) + return map[int64][]*ChannelMonitorHistoryEntry{} + } + return timelineMap +} + +// pickLatest 从 latest 切片中挑出指定 model 对应项,未命中返回 nil。 +func pickLatest(rows []*ChannelMonitorLatest, model string) *ChannelMonitorLatest { + if model == "" { + return nil + } + for _, r := range rows { + if r.Model == model { + return r + } + } + return nil } // GetUserDetail 用户只读视图:单个监控详情(每个模型 7d/15d/30d 可用率与平均延迟)。 @@ -170,9 +220,15 @@ func buildStatusSummary( return summary } -// buildUserViewFromSummary 用预聚合好的 MonitorStatusSummary 装填 UserMonitorView(无 IO)。 -func buildUserViewFromSummary(m *ChannelMonitor, summary MonitorStatusSummary) *UserMonitorView { - return &UserMonitorView{ +// buildUserViewFromSummary 用预聚合好的 MonitorStatusSummary + 主模型 latest + timeline 装填 UserMonitorView(无 IO)。 +// primaryLatest 可能为 nil(该监控尚无历史);timelineEntries 可能为空。 +func buildUserViewFromSummary( + m *ChannelMonitor, + summary MonitorStatusSummary, + primaryLatest *ChannelMonitorLatest, + timelineEntries []*ChannelMonitorHistoryEntry, +) *UserMonitorView { + view := &UserMonitorView{ ID: m.ID, Name: m.Name, Provider: m.Provider, @@ -182,7 +238,26 @@ func buildUserViewFromSummary(m *ChannelMonitor, summary MonitorStatusSummary) * PrimaryLatencyMs: summary.PrimaryLatencyMs, Availability7d: summary.Availability7d, ExtraModels: summary.ExtraModels, + Timeline: buildTimelinePoints(timelineEntries), } + if primaryLatest != nil { + view.PrimaryPingLatencyMs = primaryLatest.PingLatencyMs + } + return view +} + +// buildTimelinePoints 把 history entry 裁剪为 timeline 点(去除 message/ID/Model,减小响应体)。 +func buildTimelinePoints(entries []*ChannelMonitorHistoryEntry) []UserMonitorTimelinePoint { + out := make([]UserMonitorTimelinePoint, 0, len(entries)) + for _, e := range entries { + out = append(out, UserMonitorTimelinePoint{ + Status: e.Status, + LatencyMs: e.LatencyMs, + PingLatencyMs: e.PingLatencyMs, + CheckedAt: e.CheckedAt, + }) + } + return out } // mergeModelDetails 合并 latest + availability 三个窗口为 ModelDetail 列表。 diff --git a/backend/internal/service/channel_monitor_const.go b/backend/internal/service/channel_monitor_const.go index b4c02bcb..7255e4be 100644 --- a/backend/internal/service/channel_monitor_const.go +++ b/backend/internal/service/channel_monitor_const.go @@ -65,6 +65,9 @@ const ( // MonitorHistoryMaxLimit 历史查询最大返回条数(handler 层共享)。 MonitorHistoryMaxLimit = 1000 + // monitorTimelineMaxPoints 用户视图 timeline 每个监控最多返回的历史点数。 + monitorTimelineMaxPoints = 60 + // monitorEndpointResolveTimeout validateEndpoint 解析 hostname 的最长耗时。 monitorEndpointResolveTimeout = 5 * time.Second diff --git a/backend/internal/service/channel_monitor_service.go b/backend/internal/service/channel_monitor_service.go index b179e50c..957ace15 100644 --- a/backend/internal/service/channel_monitor_service.go +++ b/backend/internal/service/channel_monitor_service.go @@ -38,6 +38,9 @@ type ChannelMonitorRepository interface { // 批量聚合(admin/user list 用,避免 N+1) ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*ChannelMonitorLatest, error) ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*ChannelMonitorAvailability, error) + // ListRecentHistoryForMonitors 批量取多个 monitor 各自主模型(primaryModels[monitorID])最近 perMonitorLimit 条历史。 + // 返回的 entry 已按 checked_at DESC 排序(最新在前),不含 message 字段。 + ListRecentHistoryForMonitors(ctx context.Context, ids []int64, primaryModels map[int64]string, perMonitorLimit int) (map[int64][]*ChannelMonitorHistoryEntry, error) } // ChannelMonitorService 渠道监控管理服务。 diff --git a/backend/internal/service/channel_monitor_types.go b/backend/internal/service/channel_monitor_types.go index 4b34d8af..739c82fb 100644 --- a/backend/internal/service/channel_monitor_types.go +++ b/backend/internal/service/channel_monitor_types.go @@ -72,15 +72,25 @@ type CheckResult struct { // UserMonitorView 用户只读视图:监控概览(含主模型最近状态 + 7d 可用率 + 附加模型最近状态)。 type UserMonitorView struct { - ID int64 - Name string - Provider string - GroupName string - PrimaryModel string - PrimaryStatus string - PrimaryLatencyMs *int - Availability7d float64 // 0-100 - ExtraModels []ExtraModelStatus + ID int64 + Name string + Provider string + GroupName string + PrimaryModel string + PrimaryStatus string + PrimaryLatencyMs *int + PrimaryPingLatencyMs *int // 主模型最近一次 ping 延迟 + Availability7d float64 // 0-100 + ExtraModels []ExtraModelStatus + Timeline []UserMonitorTimelinePoint // 主模型最近 N 个历史点(按 checked_at DESC,最新在前) +} + +// UserMonitorTimelinePoint 用户视图 timeline 单点数据(去除 message 以减小响应体)。 +type UserMonitorTimelinePoint struct { + Status string `json:"status"` + LatencyMs *int `json:"latency_ms"` + PingLatencyMs *int `json:"ping_latency_ms"` + CheckedAt time.Time `json:"checked_at"` } // ExtraModelStatus 附加模型最近一次状态。 @@ -134,10 +144,11 @@ type ChannelMonitorHistoryEntry struct { // ChannelMonitorLatest 最近一次检测的简明信息(用于 UserMonitorView 聚合)。 type ChannelMonitorLatest struct { - Model string - Status string - LatencyMs *int - CheckedAt time.Time + Model string + Status string + LatencyMs *int + PingLatencyMs *int + CheckedAt time.Time } // ChannelMonitorAvailability 单个模型在某窗口内的可用率与平均延迟(用于 UserMonitorDetail 聚合)。 diff --git a/frontend/src/api/channelMonitor.ts b/frontend/src/api/channelMonitor.ts index c5481636..38dd0c99 100644 --- a/frontend/src/api/channelMonitor.ts +++ b/frontend/src/api/channelMonitor.ts @@ -14,6 +14,13 @@ export interface UserMonitorExtraModel { latency_ms: number | null } +export interface MonitorTimelinePoint { + status: MonitorStatus + latency_ms: number | null + ping_latency_ms: number | null + checked_at: string +} + export interface UserMonitorView { id: number name: string @@ -22,8 +29,10 @@ export interface UserMonitorView { primary_model: string primary_status: MonitorStatus primary_latency_ms: number | null + primary_ping_latency_ms: number | null availability_7d: number extra_models: UserMonitorExtraModel[] + timeline: MonitorTimelinePoint[] } export interface UserMonitorListResponse { diff --git a/frontend/src/components/user/MonitorPrimaryModelCell.vue b/frontend/src/components/user/MonitorPrimaryModelCell.vue deleted file mode 100644 index 32620b2a..00000000 --- a/frontend/src/components/user/MonitorPrimaryModelCell.vue +++ /dev/null @@ -1,71 +0,0 @@ - - - diff --git a/frontend/src/components/user/monitor/MonitorAvailabilityRow.vue b/frontend/src/components/user/monitor/MonitorAvailabilityRow.vue new file mode 100644 index 00000000..34420c9d --- /dev/null +++ b/frontend/src/components/user/monitor/MonitorAvailabilityRow.vue @@ -0,0 +1,49 @@ + + + diff --git a/frontend/src/components/user/monitor/MonitorCard.vue b/frontend/src/components/user/monitor/MonitorCard.vue new file mode 100644 index 00000000..33742c6d --- /dev/null +++ b/frontend/src/components/user/monitor/MonitorCard.vue @@ -0,0 +1,128 @@ + + + diff --git a/frontend/src/components/user/monitor/MonitorCardGrid.vue b/frontend/src/components/user/monitor/MonitorCardGrid.vue new file mode 100644 index 00000000..c7d24c01 --- /dev/null +++ b/frontend/src/components/user/monitor/MonitorCardGrid.vue @@ -0,0 +1,81 @@ + + + diff --git a/frontend/src/components/user/monitor/MonitorHero.vue b/frontend/src/components/user/monitor/MonitorHero.vue new file mode 100644 index 00000000..be5a96b8 --- /dev/null +++ b/frontend/src/components/user/monitor/MonitorHero.vue @@ -0,0 +1,133 @@ + + + diff --git a/frontend/src/components/user/monitor/MonitorMetricPair.vue b/frontend/src/components/user/monitor/MonitorMetricPair.vue new file mode 100644 index 00000000..0f3fd3dc --- /dev/null +++ b/frontend/src/components/user/monitor/MonitorMetricPair.vue @@ -0,0 +1,45 @@ + + + diff --git a/frontend/src/components/user/monitor/MonitorTimeline.vue b/frontend/src/components/user/monitor/MonitorTimeline.vue new file mode 100644 index 00000000..b4d0c151 --- /dev/null +++ b/frontend/src/components/user/monitor/MonitorTimeline.vue @@ -0,0 +1,113 @@ + + + diff --git a/frontend/src/components/user/monitor/ProviderIcon.vue b/frontend/src/components/user/monitor/ProviderIcon.vue new file mode 100644 index 00000000..20456a2c --- /dev/null +++ b/frontend/src/components/user/monitor/ProviderIcon.vue @@ -0,0 +1,71 @@ + + + diff --git a/frontend/src/composables/useChannelMonitorFormat.ts b/frontend/src/composables/useChannelMonitorFormat.ts index fbb310fa..7ffdaa42 100644 --- a/frontend/src/composables/useChannelMonitorFormat.ts +++ b/frontend/src/composables/useChannelMonitorFormat.ts @@ -4,6 +4,7 @@ * Centralises: * - status / provider label + badge class lookups * - latency / availability / percent number formatting + * - dashboard-style helpers (HSL for availability, provider gradient, relative time) * * i18n keys live under `monitorCommon.*` so admin and user views share the * same translation source. @@ -23,6 +24,11 @@ import { const NEUTRAL_BADGE = 'bg-gray-100 text-gray-800 dark:bg-dark-700 dark:text-gray-300' +/** Availability HSL hue multiplier: 0%=red(0) / 50%=yellow(60) / 100%=green(120). */ +const HSL_HUE_PER_PERCENT = 1.2 +const HSL_SATURATION = 72 +const HSL_LIGHTNESS = 42 + export interface AvailabilityRow { primary_status: MonitorStatus | '' availability_7d: number | null | undefined @@ -39,11 +45,11 @@ export function useChannelMonitorFormat() { function statusBadgeClass(s: MonitorStatus | ''): string { switch (s) { case STATUS_OPERATIONAL: - return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300' + return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-500/15 dark:text-emerald-300' case STATUS_DEGRADED: - return 'bg-yellow-100 text-yellow-800 dark:bg-yellow-900/30 dark:text-yellow-300' + return 'bg-amber-100 text-amber-700 dark:bg-amber-500/15 dark:text-amber-300' case STATUS_FAILED: - return 'bg-red-100 text-red-800 dark:bg-red-900/30 dark:text-red-300' + return 'bg-red-100 text-red-700 dark:bg-red-500/15 dark:text-red-300' case STATUS_ERROR: default: return NEUTRAL_BADGE @@ -60,11 +66,11 @@ export function useChannelMonitorFormat() { function providerBadgeClass(p: Provider | string): string { switch (p) { case PROVIDER_OPENAI: - return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300' + return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-500/15 dark:text-emerald-300' case PROVIDER_ANTHROPIC: - return 'bg-orange-100 text-orange-800 dark:bg-orange-900/30 dark:text-orange-300' + return 'bg-orange-100 text-orange-700 dark:bg-orange-500/15 dark:text-orange-300' case PROVIDER_GEMINI: - return 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-300' + return 'bg-sky-100 text-sky-700 dark:bg-sky-500/15 dark:text-sky-300' default: return NEUTRAL_BADGE } @@ -85,6 +91,20 @@ export function useChannelMonitorFormat() { return formatPercent(row.availability_7d) } + function formatRelativeTime(iso: string | null | undefined): string { + if (!iso) return t('monitorCommon.latencyEmpty') + const ts = Date.parse(iso) + if (Number.isNaN(ts)) return t('monitorCommon.latencyEmpty') + const diffSec = Math.max(0, Math.floor((Date.now() - ts) / 1000)) + if (diffSec < 60) return t('monitorCommon.relativeSecondsAgo', { n: diffSec }) + const diffMin = Math.floor(diffSec / 60) + if (diffMin < 60) return t('monitorCommon.relativeMinutesAgo', { n: diffMin }) + const diffHour = Math.floor(diffMin / 60) + if (diffHour < 24) return t('monitorCommon.relativeHoursAgo', { n: diffHour }) + const diffDay = Math.floor(diffHour / 24) + return t('monitorCommon.relativeDaysAgo', { n: diffDay }) + } + return { statusLabel, statusBadgeClass, @@ -93,5 +113,33 @@ export function useChannelMonitorFormat() { formatLatency, formatPercent, formatAvailability, + formatRelativeTime, + } +} + +/** + * Map availability percent to an HSL colour (red -> yellow -> green). + * Returns undefined for null/NaN so callers can fall back to a neutral colour. + */ +export function hslForPct(pct: number | null | undefined): string | undefined { + if (pct === null || pct === undefined || Number.isNaN(pct)) return undefined + const clamped = Math.max(0, Math.min(100, pct)) + const hue = clamped * HSL_HUE_PER_PERCENT + return `hsl(${hue} ${HSL_SATURATION}% ${HSL_LIGHTNESS}%)` +} + +/** + * Tailwind gradient class for the provider icon tile background. + */ +export function providerGradient(provider: string): string { + switch (provider) { + case PROVIDER_OPENAI: + return 'bg-gradient-to-br from-emerald-50 to-emerald-100 dark:from-emerald-500/10 dark:to-emerald-500/20' + case PROVIDER_ANTHROPIC: + return 'bg-gradient-to-br from-orange-50 to-amber-100 dark:from-orange-500/10 dark:to-amber-500/20' + case PROVIDER_GEMINI: + return 'bg-gradient-to-br from-sky-50 to-indigo-100 dark:from-sky-500/10 dark:to-indigo-500/20' + default: + return 'bg-gradient-to-br from-gray-100 to-gray-200 dark:from-dark-700 dark:to-dark-600' } } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 32fbce19..b95c8b44 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -867,7 +867,22 @@ export default { }, extraModelsHeader: 'Extra Models', extraModelsEmpty: 'No extra models', - latencyEmpty: '-' + latencyEmpty: '-', + availabilityPrefix: 'Availability', + dialogLatency: 'Dialog Latency', + endpointPing: 'Endpoint PING', + history60pts: 'HISTORY ({n} PTS)', + nextUpdateIn: 'NEXT UPDATE IN {n}s', + past: 'PAST', + now: 'NOW', + maintenancePaused: 'Maintenance · timeline paused', + extraModelsCount: '+ {n} models', + pollEvery: '{n}s polling', + updatedAt: 'Updated {time}', + relativeSecondsAgo: '{n}s ago', + relativeMinutesAgo: '{n}m ago', + relativeHoursAgo: '{n}h ago', + relativeDaysAgo: '{n}d ago' }, // Channel Status (user-facing read-only view) @@ -880,6 +895,22 @@ export default { detailLoadError: 'Failed to load channel detail', detailTitle: 'Channel Detail', closeDetail: 'Close', + hero: { + breadcrumb: 'CHANNEL · STATUS', + title: 'INTELLIGENCE MONITOR', + subtitleZh: 'Real-time tracking of availability, latency and status for leading AI endpoints.', + subtitleEn: 'Advanced performance metrics for next-gen intelligence.' + }, + windowTab: { + '7d': '7 days', + '15d': '15 days', + '30d': '30 days' + }, + overall: { + operational: 'OPERATIONAL', + degraded: 'DEGRADED', + unavailable: 'UNAVAILABLE' + }, columns: { name: 'Name', provider: 'Provider', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index dd3af363..54bc03c5 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -871,7 +871,22 @@ export default { }, extraModelsHeader: '附加模型', extraModelsEmpty: '无附加模型', - latencyEmpty: '-' + latencyEmpty: '-', + availabilityPrefix: '可用性', + dialogLatency: '对话延迟', + endpointPing: '端点 PING', + history60pts: '近 {n} 次记录', + nextUpdateIn: '{n}s 后刷新', + past: 'PAST', + now: 'NOW', + maintenancePaused: '维护中 · 已暂停时间线采集', + extraModelsCount: '+ {n} 模型', + pollEvery: '{n}s 轮询', + updatedAt: '更新于 {time}', + relativeSecondsAgo: '{n} 秒前', + relativeMinutesAgo: '{n} 分钟前', + relativeHoursAgo: '{n} 小时前', + relativeDaysAgo: '{n} 天前' }, // Channel Status (user-facing read-only view) @@ -884,6 +899,22 @@ export default { detailLoadError: '加载渠道详情失败', detailTitle: '渠道详情', closeDetail: '关闭', + hero: { + breadcrumb: '渠道 · 状态', + title: 'INTELLIGENCE MONITOR', + subtitleZh: '实时追踪各大 AI 模型对话接口的可用性、延迟与官方服务状态。', + subtitleEn: 'Advanced performance metrics for next-gen intelligence.' + }, + windowTab: { + '7d': '7 天', + '15d': '15 天', + '30d': '30 天' + }, + overall: { + operational: 'OPERATIONAL', + degraded: 'DEGRADED', + unavailable: 'UNAVAILABLE' + }, columns: { name: '名称', provider: '供应商', diff --git a/frontend/src/views/user/ChannelStatusView.vue b/frontend/src/views/user/ChannelStatusView.vue index 9f5fe8d1..af427cca 100644 --- a/frontend/src/views/user/ChannelStatusView.vue +++ b/frontend/src/views/user/ChannelStatusView.vue @@ -1,93 +1,23 @@ @@ -174,6 +175,7 @@

{{ errorMessage }}

+

{{ errorHintMessage }}

@@ -281,6 +283,7 @@ import SubscriptionPlanCard from '@/components/payment/SubscriptionPlanCard.vue' import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue' import Icon from '@/components/icons/Icon.vue' import type { PaymentMethodOption } from '@/components/payment/PaymentMethodSelector.vue' +import { describePaymentScenarioError } from './paymentUx' const { t } = useI18n() const route = useRoute() @@ -301,6 +304,7 @@ function getDaysRemaining(expiresAt: string): number { const loading = ref(true) const submitting = ref(false) const errorMessage = ref('') +const errorHintMessage = ref('') const activeTab = ref<'recharge' | 'subscription'>('recharge') const amount = ref(null) const selectedMethod = ref('') @@ -619,6 +623,7 @@ async function confirmSubscribe() { async function createOrder(orderAmount: number, orderType: OrderType, planId?: number, options: CreateOrderOptions = {}) { submitting.value = true errorMessage.value = '' + errorHintMessage.value = '' try { const requestType = normalizeVisibleMethod(options.paymentType || selectedMethod.value) || options.paymentType || selectedMethod.value const payload = buildCreateOrderPayload({ @@ -668,8 +673,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } if (decision.kind === 'unhandled') { - errorMessage.value = t('payment.result.failed') - appStore.showError(errorMessage.value) + applyScenarioError({ reason: 'UNHANDLED_PAYMENT_SCENARIO' }, visibleMethod) return } @@ -691,7 +695,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n if (errMsg.includes('cancel')) { appStore.showInfo(t('payment.qr.cancelled')) } else if (errMsg && !errMsg.includes('ok')) { - appStore.showError(t('payment.result.failed')) + applyScenarioError({ reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, visibleMethod) } return } @@ -707,10 +711,16 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n if (apiErr.reason === 'TOO_MANY_PENDING') { const metadata = apiErr.metadata as Record | undefined errorMessage.value = t('payment.errors.tooManyPending', { max: metadata?.max || '' }) + errorHintMessage.value = '' } else if (apiErr.reason === 'CANCEL_RATE_LIMITED') { errorMessage.value = t('payment.errors.cancelRateLimited') + errorHintMessage.value = '' } else { - errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed')) + applyScenarioError(err, normalizeVisibleMethod(options.paymentType || selectedMethod.value) || selectedMethod.value) + if (!errorMessage.value) { + errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed')) + errorHintMessage.value = '' + } } appStore.showError(errorMessage.value) } finally { @@ -718,6 +728,21 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } } +function applyScenarioError(err: unknown, paymentMethod: string) { + const descriptor = describePaymentScenarioError(err, { + paymentMethod, + isMobile: isMobileDevice(), + isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent), + }) + if (!descriptor) { + errorMessage.value = '' + errorHintMessage.value = '' + return + } + errorMessage.value = t(descriptor.messageKey) + errorHintMessage.value = descriptor.hintKey ? t(descriptor.hintKey) : '' +} + async function resumeWechatPaymentFromQuery() { const openid = readRouteQueryValue(route.query.openid) if (readRouteQueryValue(route.query.wechat_resume) !== '1' || !openid) { diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts index d23a60d9..b1caa526 100644 --- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -155,4 +155,29 @@ describe('PaymentResultView', () => { expect(wrapper.text()).toContain('payment.result.success') expect(verifyOrderPublic).not.toHaveBeenCalled() }) + + it('normalizes aliased payment methods before rendering the label', async () => { + routeState.query = { + resume_token: 'resume-88', + } + resolveOrderPublicByResumeToken.mockResolvedValueOnce({ + data: { + ...orderFactory('PAID'), + payment_type: 'alipay_direct', + }, + }) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('payment.methods.alipay') + expect(wrapper.text()).not.toContain('payment.methods.alipay_direct') + }) }) diff --git a/frontend/src/views/user/__tests__/paymentUx.spec.ts b/frontend/src/views/user/__tests__/paymentUx.spec.ts new file mode 100644 index 00000000..6cf105f2 --- /dev/null +++ b/frontend/src/views/user/__tests__/paymentUx.spec.ts @@ -0,0 +1,49 @@ +import { describe, expect, it } from 'vitest' +import { + describePaymentScenarioError, + normalizePaymentMethodForDisplay, +} from '../paymentUx' + +describe('normalizePaymentMethodForDisplay', () => { + it('collapses visible payment aliases to canonical method ids', () => { + expect(normalizePaymentMethodForDisplay(' alipay_direct ')).toBe('alipay') + expect(normalizePaymentMethodForDisplay('wxpay_direct')).toBe('wxpay') + expect(normalizePaymentMethodForDisplay('wechat_pay')).toBe('wxpay') + }) + + it('leaves non-aliased methods untouched', () => { + expect(normalizePaymentMethodForDisplay('stripe')).toBe('stripe') + }) +}) + +describe('describePaymentScenarioError', () => { + it('maps WeChat H5 authorization errors to explicit in-app guidance', () => { + expect(describePaymentScenarioError( + { reason: 'WECHAT_H5_NOT_AUTHORIZED' }, + { paymentMethod: 'wxpay', isMobile: true, isWechatBrowser: false }, + )).toEqual({ + messageKey: 'payment.errors.wechatH5NotAuthorized', + hintKey: 'payment.errors.wechatOpenInWeChatHint', + }) + }) + + it('maps missing WeixinJSBridge to a JSAPI-specific prompt', () => { + expect(describePaymentScenarioError( + new Error('WeixinJSBridge is unavailable'), + { paymentMethod: 'wxpay', isMobile: true, isWechatBrowser: true }, + )).toEqual({ + messageKey: 'payment.errors.wechatJsapiUnavailable', + hintKey: 'payment.errors.wechatOpenInWeChatHint', + }) + }) + + it('maps generic desktop Alipay failures to QR guidance', () => { + expect(describePaymentScenarioError( + { reason: 'PAYMENT_GATEWAY_ERROR' }, + { paymentMethod: 'alipay', isMobile: false, isWechatBrowser: false }, + )).toEqual({ + messageKey: 'payment.errors.alipayDesktopUnavailable', + hintKey: 'payment.errors.alipayDesktopQrHint', + }) + }) +}) diff --git a/frontend/src/views/user/paymentUx.ts b/frontend/src/views/user/paymentUx.ts new file mode 100644 index 00000000..443529a7 --- /dev/null +++ b/frontend/src/views/user/paymentUx.ts @@ -0,0 +1,105 @@ +import { normalizeVisibleMethod } from '@/components/payment/paymentFlow' +import { extractApiErrorCode } from '@/utils/apiError' + +const DISPLAY_METHOD_ALIASES: Record = { + wechat: 'wxpay', + wechat_pay: 'wxpay', +} + +export interface PaymentScenarioContext { + paymentMethod: string + isMobile: boolean + isWechatBrowser: boolean +} + +export interface PaymentScenarioErrorDescriptor { + messageKey: string + hintKey?: string +} + +export function normalizePaymentMethodForDisplay(paymentType: string): string { + const trimmed = paymentType.trim().toLowerCase() + const visibleMethod = normalizeVisibleMethod(trimmed) + if (visibleMethod) return visibleMethod + return DISPLAY_METHOD_ALIASES[trimmed] ?? trimmed +} + +export function paymentMethodI18nKey(paymentType: string): string { + return `payment.methods.${normalizePaymentMethodForDisplay(paymentType)}` +} + +function defaultWechatHint(context: PaymentScenarioContext): string { + if (!context.isMobile) return 'payment.errors.wechatScanOnDesktopHint' + return 'payment.errors.wechatOpenInWeChatHint' +} + +function defaultAlipayHint(context: PaymentScenarioContext): string { + if (context.isMobile) return 'payment.errors.alipayMobileOpenHint' + return 'payment.errors.alipayDesktopQrHint' +} + +export function describePaymentScenarioError( + error: unknown, + context: PaymentScenarioContext, +): PaymentScenarioErrorDescriptor | null { + const method = normalizePaymentMethodForDisplay(context.paymentMethod) + const code = extractApiErrorCode(error) + const message = error instanceof Error + ? error.message + : (typeof error === 'object' && error && 'message' in error && typeof error.message === 'string' + ? error.message + : String(error || '')) + const normalizedMessage = message.toLowerCase() + + if (method === 'wxpay') { + if (code === 'WECHAT_H5_NOT_AUTHORIZED') { + return { + messageKey: 'payment.errors.wechatH5NotAuthorized', + hintKey: defaultWechatHint(context), + } + } + if (code === 'WECHAT_PAYMENT_MP_NOT_CONFIGURED') { + return { + messageKey: 'payment.errors.wechatPaymentMpNotConfigured', + hintKey: context.isWechatBrowser + ? 'payment.errors.wechatSwitchBrowserHint' + : defaultWechatHint(context), + } + } + if (code === 'NO_AVAILABLE_INSTANCE') { + return { + messageKey: 'payment.errors.wechatUnavailable', + hintKey: defaultWechatHint(context), + } + } + if (code === 'WECHAT_JSAPI_FAILED' || normalizedMessage.includes('get_brand_wcpay_request:fail')) { + return { + messageKey: 'payment.errors.wechatJsapiFailed', + hintKey: defaultWechatHint(context), + } + } + if (normalizedMessage.includes('weixinjsbridge is unavailable')) { + return { + messageKey: 'payment.errors.wechatJsapiUnavailable', + hintKey: 'payment.errors.wechatOpenInWeChatHint', + } + } + if (code === 'PAYMENT_GATEWAY_ERROR' || code === 'UNHANDLED_PAYMENT_SCENARIO') { + return { + messageKey: 'payment.errors.wechatUnavailable', + hintKey: defaultWechatHint(context), + } + } + } + + if (method === 'alipay' && (code === 'PAYMENT_GATEWAY_ERROR' || code === 'UNHANDLED_PAYMENT_SCENARIO')) { + return { + messageKey: context.isMobile + ? 'payment.errors.alipayMobileUnavailable' + : 'payment.errors.alipayDesktopUnavailable', + hintKey: defaultAlipayHint(context), + } + } + + return null +} From 9e84e2fd2bed765d38ca4a438fbbf1739950ae32 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 00:05:17 +0800 Subject: [PATCH 061/326] fix: persist admin payment visibility and scheduler settings --- .../internal/handler/admin/setting_handler.go | 64 ++++++++++++++++ ...tting_handler_auth_source_defaults_test.go | 74 +++++++++++++++++++ backend/internal/handler/dto/settings.go | 9 +++ backend/internal/server/api_contract_test.go | 63 +++++++++++++++- .../service/admin_service_apikey_test.go | 6 ++ backend/internal/service/setting_service.go | 54 +++++++++++++- .../service/setting_service_update_test.go | 31 ++++++++ backend/internal/service/settings_view.go | 9 +++ backend/internal/service/user_service_test.go | 6 ++ 9 files changed, 311 insertions(+), 5 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index fe5c7928..e5681208 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -180,6 +180,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableCCHSigning: settings.EnableCCHSigning, WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, + PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource, + PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource, + PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled, + PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled, + OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled, BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, @@ -338,6 +343,15 @@ type UpdateSettingsRequest struct { EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableCCHSigning *bool `json:"enable_cch_signing"` + // Payment visible method routing + PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` + PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"` + PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"` + PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"` + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"` + // Balance low notification BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` @@ -935,6 +949,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.EnableCCHSigning }(), + PaymentVisibleMethodAlipaySource: func() string { + if req.PaymentVisibleMethodAlipaySource != nil { + return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource) + } + return previousSettings.PaymentVisibleMethodAlipaySource + }(), + PaymentVisibleMethodWxpaySource: func() string { + if req.PaymentVisibleMethodWxpaySource != nil { + return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource) + } + return previousSettings.PaymentVisibleMethodWxpaySource + }(), + PaymentVisibleMethodAlipayEnabled: func() bool { + if req.PaymentVisibleMethodAlipayEnabled != nil { + return *req.PaymentVisibleMethodAlipayEnabled + } + return previousSettings.PaymentVisibleMethodAlipayEnabled + }(), + PaymentVisibleMethodWxpayEnabled: func() bool { + if req.PaymentVisibleMethodWxpayEnabled != nil { + return *req.PaymentVisibleMethodWxpayEnabled + } + return previousSettings.PaymentVisibleMethodWxpayEnabled + }(), + OpenAIAdvancedSchedulerEnabled: func() bool { + if req.OpenAIAdvancedSchedulerEnabled != nil { + return *req.OpenAIAdvancedSchedulerEnabled + } + return previousSettings.OpenAIAdvancedSchedulerEnabled + }(), BalanceLowNotifyEnabled: func() bool { if req.BalanceLowNotifyEnabled != nil { return *req.BalanceLowNotifyEnabled @@ -1153,6 +1197,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableCCHSigning: updatedSettings.EnableCCHSigning, + PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource, + PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource, + PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled, + PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled, + OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled, BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, @@ -1455,6 +1504,21 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EnableCCHSigning != after.EnableCCHSigning { changed = append(changed, "enable_cch_signing") } + if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource { + changed = append(changed, "payment_visible_method_alipay_source") + } + if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource { + changed = append(changed, "payment_visible_method_wxpay_source") + } + if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled { + changed = append(changed, "payment_visible_method_alipay_enabled") + } + if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled { + changed = append(changed, "payment_visible_method_wxpay_enabled") + } + if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled { + changed = append(changed, "openai_advanced_scheduler_enabled") + } // Balance & quota notification if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled { changed = append(changed, "balance_low_notify_enabled") diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go index b26fa447..bf51fc68 100644 --- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -147,3 +147,77 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) require.Equal(t, true, data["force_email_on_third_party_signup"]) } + +func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "payment_visible_method_alipay_source": "easypay", + "payment_visible_method_wxpay_source": "wxpay", + "payment_visible_method_alipay_enabled": true, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": true, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource]) + require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled]) + require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled]) + require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"]) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"]) + require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"]) + require.Equal(t, true, data["payment_visible_method_alipay_enabled"]) + require.Equal(t, false, data["payment_visible_method_wxpay_enabled"]) + require.Equal(t, true, data["openai_advanced_scheduler_enabled"]) +} + +func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "payment_visible_method_alipay_source": "bogus", + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 637e317b..cc3f8496 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -127,6 +127,15 @@ type SystemSettings struct { // Web Search Emulation WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"` + // Payment visible method routing + PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"` + PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"` + PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"` + PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"` + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"` + // Payment configuration PaymentEnabled bool `json:"payment_enabled"` PaymentMinAmount float64 `json:"payment_min_amount"` diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index e903898f..533f2dac 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -500,10 +500,15 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyTableDefaultPageSize: "20", service.SettingKeyTablePageSizeOptions: "[10,20,50,100]", - service.SettingKeyOpsMonitoringEnabled: "false", - service.SettingKeyOpsRealtimeMonitoringEnabled: "true", - service.SettingKeyOpsQueryModeDefault: "auto", - service.SettingKeyOpsMetricsIntervalSeconds: "60", + service.SettingKeyOpsMonitoringEnabled: "false", + service.SettingKeyOpsRealtimeMonitoringEnabled: "true", + service.SettingKeyOpsQueryModeDefault: "auto", + service.SettingKeyOpsMetricsIntervalSeconds: "60", + service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay, + service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat, + service.SettingPaymentVisibleMethodAlipayEnabled: "true", + service.SettingPaymentVisibleMethodWxpayEnabled: "false", + "openai_advanced_scheduler_enabled": "true", }) }, method: http.MethodGet, @@ -567,6 +572,27 @@ func TestAPIContracts(t *testing.T) { "api_base_url": "https://api.example.com", "contact_info": "support", "doc_url": "https://docs.example.com", + "auth_source_default_email_balance": 0, + "auth_source_default_email_concurrency": 5, + "auth_source_default_email_subscriptions": [], + "auth_source_default_email_grant_on_signup": false, + "auth_source_default_email_grant_on_first_bind": false, + "auth_source_default_linuxdo_balance": 0, + "auth_source_default_linuxdo_concurrency": 5, + "auth_source_default_linuxdo_subscriptions": [], + "auth_source_default_linuxdo_grant_on_signup": false, + "auth_source_default_linuxdo_grant_on_first_bind": false, + "auth_source_default_oidc_balance": 0, + "auth_source_default_oidc_concurrency": 5, + "auth_source_default_oidc_subscriptions": [], + "auth_source_default_oidc_grant_on_signup": false, + "auth_source_default_oidc_grant_on_first_bind": false, + "auth_source_default_wechat_balance": 0, + "auth_source_default_wechat_concurrency": 5, + "auth_source_default_wechat_subscriptions": [], + "auth_source_default_wechat_grant_on_signup": false, + "auth_source_default_wechat_grant_on_first_bind": false, + "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, "default_subscriptions": [], @@ -592,6 +618,11 @@ func TestAPIContracts(t *testing.T) { "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, "web_search_emulation_enabled": false, + "payment_visible_method_alipay_source": "easypay_alipay", + "payment_visible_method_wxpay_source": "official_wxpay", + "payment_visible_method_alipay_enabled": true, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": true, "custom_menu_items": [], "custom_endpoints": [], "payment_enabled": false, @@ -858,6 +889,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } +func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + return errors.New("not implemented") +} + func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -894,6 +937,18 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64 return errors.New("not implemented") } +func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + return nil, nil +} + func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { return errors.New("not implemented") } diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 487fb5f1..e2eae0b4 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -82,6 +82,12 @@ func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { panic("unexpected") } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index a2644fcd..02a64c1c 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -566,6 +566,16 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet normalizedWhitelist = []string{} } settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist + alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled) + if err != nil { + return err + } + wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled) + if err != nil { + return err + } + settings.PaymentVisibleMethodAlipaySource = alipaySource + settings.PaymentVisibleMethodWxpaySource = wxpaySource updates := make(map[string]string) @@ -701,6 +711,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) + updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource + updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource + updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled) + updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled) + updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled) // Balance low notification updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) @@ -730,6 +745,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet cchSigning: settings.EnableCCHSigning, expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), }) + openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: settings.OpenAIAdvancedSchedulerEnabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) if s.onUpdate != nil { s.onUpdate() // Invalidate cache after settings update } @@ -1137,7 +1157,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyMaxClaudeCodeVersion: "", // 分组隔离(默认不允许未分组 Key 调度) - SettingKeyAllowUngroupedKeyScheduling: "false", + SettingKeyAllowUngroupedKeyScheduling: "false", + SettingPaymentVisibleMethodAlipaySource: "", + SettingPaymentVisibleMethodWxpaySource: "", + SettingPaymentVisibleMethodAlipayEnabled: "false", + SettingPaymentVisibleMethodWxpayEnabled: "false", + openAIAdvancedSchedulerSettingKey: "false", } return s.settingRepo.SetMultiple(ctx, defaults) @@ -1429,6 +1454,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0 } } + result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource]) + result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource]) + result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true" + result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true" + result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true" // Balance low notification result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" @@ -1458,6 +1488,28 @@ func isFalseSettingValue(value string) bool { } } +func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) { + source = strings.TrimSpace(source) + if source == "" { + if enabled { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source is required when the visible method is enabled", method), + ) + } + return "", nil + } + + normalized := NormalizeVisibleMethodSource(method, source) + if normalized == "" { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source must be one of the supported payment providers", method), + ) + } + return normalized, nil +} + func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { raw = strings.TrimSpace(raw) if raw == "" { diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index e62218b4..9dc0ca59 100644 --- a/backend/internal/service/setting_service_update_test.go +++ b/backend/internal/service/setting_service_update_test.go @@ -223,3 +223,34 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize]) require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions]) } + +func TestSettingService_UpdateSettings_PaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + PaymentVisibleMethodAlipaySource: "alipay", + PaymentVisibleMethodWxpaySource: "easypay", + PaymentVisibleMethodAlipayEnabled: true, + PaymentVisibleMethodWxpayEnabled: false, + OpenAIAdvancedSchedulerEnabled: true, + }) + require.NoError(t, err) + require.Equal(t, VisibleMethodSourceOfficialAlipay, repo.updates[SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, VisibleMethodSourceEasyPayWechat, repo.updates[SettingPaymentVisibleMethodWxpaySource]) + require.Equal(t, "true", repo.updates[SettingPaymentVisibleMethodAlipayEnabled]) + require.Equal(t, "false", repo.updates[SettingPaymentVisibleMethodWxpayEnabled]) + require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey]) +} + +func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + PaymentVisibleMethodAlipaySource: "not-a-provider", + }) + require.Error(t, err) + require.Equal(t, "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", infraerrors.Reason(err)) + require.Nil(t, repo.updates) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 72db4e31..9bd461f9 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -110,6 +110,15 @@ type SystemSettings struct { // Web Search Emulation WebSearchEmulationEnabled bool // 是否启用 web search 模拟 + // Payment visible method routing + PaymentVisibleMethodAlipaySource string + PaymentVisibleMethodWxpaySource string + PaymentVisibleMethodAlipayEnabled bool + PaymentVisibleMethodWxpayEnabled bool + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled bool + // Balance low notification BalanceLowNotifyEnabled bool BalanceLowNotifyThreshold float64 diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 89f0362e..e13fb95d 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -103,6 +103,12 @@ func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) er func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { return nil, nil } +func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} +func (m *mockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } From 4f6966d7b3979809da9238498bf403ced840d20b Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 00:05:42 +0800 Subject: [PATCH 062/326] frontend: route wechat oauth entry by public settings --- frontend/src/api/auth.ts | 55 ++++++ .../components/auth/WechatOAuthSection.vue | 62 ++++++- .../auth/__tests__/WechatOAuthSection.spec.ts | 170 ++++++++++++++++-- 3 files changed, 261 insertions(+), 26 deletions(-) diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 6a2feb87..6c877d76 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -349,6 +349,61 @@ export async function getPublicSettings(): Promise { return data } +export type WeChatOAuthMode = 'open' | 'mp' +export type WeChatOAuthUnavailableReason = + | 'not_configured' + | 'external_browser_required' + | 'wechat_browser_required' + +export interface ResolvedWeChatOAuthStart { + mode: WeChatOAuthMode | null + openEnabled: boolean + mpEnabled: boolean + isWeChatBrowser: boolean + unavailableReason: WeChatOAuthUnavailableReason | null +} + +type WeChatOAuthPublicSettings = { + wechat_oauth_enabled?: boolean + wechat_oauth_open_enabled?: boolean + wechat_oauth_mp_enabled?: boolean +} + +export function resolveWeChatOAuthStart( + settings: WeChatOAuthPublicSettings | null | undefined, + userAgent?: string +): ResolvedWeChatOAuthStart { + const normalizedUserAgent = (userAgent + ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '') + ?? '').trim() + const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent) + const legacyEnabled = settings?.wechat_oauth_enabled ?? false + const openEnabled = typeof settings?.wechat_oauth_open_enabled === 'boolean' + ? settings.wechat_oauth_open_enabled + : legacyEnabled + const mpEnabled = typeof settings?.wechat_oauth_mp_enabled === 'boolean' + ? settings.wechat_oauth_mp_enabled + : legacyEnabled + + if (isWeChatBrowser) { + if (mpEnabled) { + return { mode: 'mp', openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: null } + } + if (openEnabled) { + return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'external_browser_required' } + } + return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'not_configured' } + } + + if (openEnabled) { + return { mode: 'open', openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: null } + } + if (mpEnabled) { + return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'wechat_browser_required' } + } + return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'not_configured' } +} + /** * Send verification code to email * @param request - Email and optional Turnstile token diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue index 94e20222..01bbd180 100644 --- a/frontend/src/components/auth/WechatOAuthSection.vue +++ b/frontend/src/components/auth/WechatOAuthSection.vue @@ -1,6 +1,6 @@ diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts new file mode 100644 index 00000000..5e6b0ae0 --- /dev/null +++ b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts @@ -0,0 +1,243 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' +import { defineComponent, h } from 'vue' + +import AuthIdentityMigrationReportsView from '../AuthIdentityMigrationReportsView.vue' + +const { getAuthIdentityMigrationReportSummary, listAuthIdentityMigrationReports, resolveAuthIdentityMigrationReport } = vi.hoisted(() => ({ + getAuthIdentityMigrationReportSummary: vi.fn(), + listAuthIdentityMigrationReports: vi.fn(), + resolveAuthIdentityMigrationReport: vi.fn(), +})) + +const { showError, showSuccess } = vi.hoisted(() => ({ + showError: vi.fn(), + showSuccess: vi.fn(), +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + users: { + getAuthIdentityMigrationReportSummary, + listAuthIdentityMigrationReports, + resolveAuthIdentityMigrationReport, + }, + }, +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showError, + showSuccess, + }), +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + locale: { value: 'en' }, + t: (key: string) => key, + }), + } +}) + +vi.mock('@/utils/format', () => ({ + formatDateTime: (value: string | null | undefined) => value ?? '', +})) + +const sampleReport = { + id: 1, + report_type: 'oidc_synthetic_email_requires_manual_recovery', + report_key: 'legacy@example.invalid', + details: { + user_id: 42, + legacy_email: 'legacy@example.invalid', + provider_key: 'https://issuer.example', + provider_subject: 'subject-123', + }, + created_at: '2026-04-20T01:02:03Z', + resolved_at: null, + resolved_by_user_id: null, + resolution_note: '', +} + +const summaryResponse = { + total: 2, + open_total: 1, + resolved_total: 1, + by_type: { + oidc_synthetic_email_requires_manual_recovery: 2, + }, +} + +const listResponse = { + items: [sampleReport], + total: 1, + page: 1, + page_size: 20, + pages: 1, +} + +const AppLayoutStub = defineComponent({ + setup(_, { slots }) { + return () => h('div', slots.default?.()) + }, +}) + +const TablePageLayoutStub = defineComponent({ + setup(_, { slots }) { + return () => h('div', [ + slots.actions?.(), + slots.filters?.(), + slots.table?.(), + slots.default?.(), + slots.pagination?.(), + ]) + }, +}) + +const DataTableStub = defineComponent({ + props: { + columns: { type: Array, default: () => [] }, + data: { type: Array, default: () => [] }, + loading: { type: Boolean, default: false }, + }, + setup(props, { slots }) { + return () => h('div', { 'data-test': 'data-table' }, [ + props.loading + ? h('div', 'loading') + : (props.data as Array>).map((row) => + h( + 'div', + { key: String(row.id ?? row.report_key) }, + (props.columns as Array<{ key: string }>).map((column) => { + const slot = slots[`cell-${column.key}`] + return h( + 'div', + { key: column.key, [`data-test-cell`]: `${String(row.id)}-${column.key}` }, + slot + ? slot({ row, value: row[column.key] }) + : String(row[column.key] ?? '') + ) + }) + ) + ), + ]) + }, +}) + +const PaginationStub = defineComponent({ + props: { + total: { type: Number, required: true }, + page: { type: Number, required: true }, + pageSize: { type: Number, required: true }, + }, + emits: ['update:page', 'update:pageSize'], + setup(props, { emit }) { + return () => h('div', { 'data-test': 'pagination' }, [ + h('button', { + type: 'button', + 'data-test': 'next-page', + onClick: () => emit('update:page', props.page + 1), + }, 'next'), + h('button', { + type: 'button', + 'data-test': 'page-size-50', + onClick: () => emit('update:pageSize', 50), + }, '50'), + ]) + }, +}) + +describe('AuthIdentityMigrationReportsView', () => { + beforeEach(() => { + getAuthIdentityMigrationReportSummary.mockReset() + listAuthIdentityMigrationReports.mockReset() + resolveAuthIdentityMigrationReport.mockReset() + showError.mockReset() + showSuccess.mockReset() + + getAuthIdentityMigrationReportSummary.mockResolvedValue(summaryResponse) + listAuthIdentityMigrationReports.mockResolvedValue(listResponse) + resolveAuthIdentityMigrationReport.mockResolvedValue({ + ...sampleReport, + resolved_at: '2026-04-20T02:00:00Z', + resolved_by_user_id: 100, + resolution_note: 'resolved by admin', + }) + }) + + const mountView = () => + mount(AuthIdentityMigrationReportsView, { + global: { + stubs: { + AppLayout: AppLayoutStub, + TablePageLayout: TablePageLayoutStub, + DataTable: DataTableStub, + Pagination: PaginationStub, + Icon: true, + }, + }, + }) + + it('loads summary and first page of reports on mount', async () => { + const wrapper = mountView() + + await flushPromises() + + expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1) + expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ + page: 1, + pageSize: 20, + reportType: '', + }) + expect(wrapper.get('[data-test="summary-total"]').text()).toContain('2') + expect(wrapper.get('[data-test="summary-open"]').text()).toContain('1') + expect(wrapper.get('[data-test="summary-resolved"]').text()).toContain('1') + expect(wrapper.text()).toContain('legacy@example.invalid') + }) + + it('reloads list when the report type filter changes', async () => { + const wrapper = mountView() + + await flushPromises() + + listAuthIdentityMigrationReports.mockClear() + + await wrapper.get('[data-test="report-type-filter"]').setValue( + 'oidc_synthetic_email_requires_manual_recovery' + ) + await flushPromises() + + expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ + page: 1, + pageSize: 20, + reportType: 'oidc_synthetic_email_requires_manual_recovery', + }) + }) + + it('submits resolve note for the selected report and refreshes data', async () => { + const wrapper = mountView() + + await flushPromises() + + getAuthIdentityMigrationReportSummary.mockClear() + listAuthIdentityMigrationReports.mockClear() + + await wrapper.get('[data-test="select-report-1"]').trigger('click') + await wrapper.get('[data-test="resolution-note"]').setValue('resolved by admin') + await wrapper.get('[data-test="resolve-submit"]').trigger('click') + await flushPromises() + + expect(resolveAuthIdentityMigrationReport).toHaveBeenCalledWith(1, 'resolved by admin') + expect(showSuccess).toHaveBeenCalled() + expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1) + expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ + page: 1, + pageSize: 20, + reportType: '', + }) + }) +}) From 9204145746623d382f2fe5f9f2f44723042942dd Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 00:11:03 +0800 Subject: [PATCH 065/326] Close profile identity and avatar loop --- backend/internal/handler/auth_handler.go | 1 + .../internal/handler/auth_linuxdo_oauth.go | 2 +- .../handler/auth_oauth_pending_flow.go | 15 +- .../handler/auth_oauth_pending_flow_test.go | 145 ++++++++++++- backend/internal/handler/auth_oidc_oauth.go | 2 +- backend/internal/handler/auth_wechat_oauth.go | 2 +- backend/internal/handler/user_handler.go | 112 +++++++++- backend/internal/handler/user_handler_test.go | 79 +++++++ backend/internal/service/user_service.go | 75 +++++-- frontend/src/api/user.ts | 1 + .../user/profile/ProfileInfoCard.vue | 202 +++++++++++++++++- .../profile/__tests__/ProfileInfoCard.spec.ts | 172 +++++++++++++++ frontend/src/i18n/locales/en.ts | 14 ++ frontend/src/i18n/locales/zh.ts | 14 ++ 14 files changed, 801 insertions(+), 35 deletions(-) create mode 100644 frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index e4697609..b984a436 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -296,6 +296,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { c.Request.Context(), h.entClient(), h.authService, + h.userService, pendingSession, decision, &user.ID, diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index a0760a3b..175b1e1f 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -495,7 +495,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 1b3c1380..94186858 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -852,6 +852,7 @@ func applyPendingOAuthBinding( ctx context.Context, client *dbent.Client, authService *service.AuthService, + userService *service.UserService, session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision, overrideUserID *int64, @@ -938,6 +939,12 @@ func applyPendingOAuthBinding( } } + if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" && userService != nil { + if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil { + return err + } + } + return tx.Commit() } @@ -945,6 +952,7 @@ func applyPendingOAuthAdoption( ctx context.Context, client *dbent.Client, authService *service.AuthService, + userService *service.UserService, session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision, overrideUserID *int64, @@ -953,6 +961,7 @@ func applyPendingOAuthAdoption( ctx, client, authService, + userService, session, decision, overrideUserID, @@ -1092,7 +1101,7 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { }) return } - if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID, true, true); err != nil { + if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) return } @@ -1188,7 +1197,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) response.ErrorFrom(c, err) return } - if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, session, decision, &user.ID, true, false); err != nil { + if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) return } @@ -1278,7 +1287,7 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, session.TargetUserID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 8c468fdc..2521186e 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -152,6 +152,11 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio require.Equal(t, "Alice Example", identity.Metadata["display_name"]) require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"]) + avatar := loadUserAvatarRecord(t, client, userEntity.ID) + require.NotNil(t, avatar) + require.Equal(t, "remote_url", avatar.StorageProvider) + require.Equal(t, "https://cdn.example/alice.png", avatar.URL) + decision, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). Only(ctx) @@ -1242,6 +1247,18 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( UNIQUE(user_id, provider_type, grant_reason) )`) require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_avatars ( + user_id INTEGER PRIMARY KEY, + storage_provider TEXT NOT NULL, + storage_key TEXT NOT NULL DEFAULT '', + url TEXT NOT NULL, + content_type TEXT NOT NULL DEFAULT '', + byte_size INTEGER NOT NULL DEFAULT 0, + sha256 TEXT NOT NULL DEFAULT '', + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +)`) + require.NoError(t, err) drv := entsql.OpenDB(dialect.SQLite, db) client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) @@ -1492,6 +1509,35 @@ func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[strin return payload } +type oauthPendingFlowAvatarRecord struct { + StorageProvider string + URL string +} + +func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord { + t.Helper() + + var rows entsql.Rows + err := client.Driver().Query( + context.Background(), + `SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`, + []any{userID}, + &rows, + ) + require.NoError(t, err) + defer rows.Close() + + if !rows.Next() { + require.NoError(t, rows.Err()) + return nil + } + + var record oauthPendingFlowAvatarRecord + require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL)) + require.NoError(t, rows.Err()) + return &record +} + func countProviderGrantRecords( t *testing.T, client *dbent.Client, @@ -1604,16 +1650,95 @@ func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { return r.client.User.DeleteOneID(id).Exec(ctx) } -func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { - return nil, nil +func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var rows entsql.Rows + if err := driver.Query( + ctx, + `SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`, + []any{userID}, + &rows, + ); err != nil { + return nil, err + } + defer rows.Close() + + if !rows.Next() { + return nil, rows.Err() + } + + var avatar service.UserAvatar + if err := rows.Scan( + &avatar.StorageProvider, + &avatar.StorageKey, + &avatar.URL, + &avatar.ContentType, + &avatar.ByteSize, + &avatar.SHA256, + ); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return &avatar, nil } -func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) { - panic("unexpected UpsertUserAvatar call") +func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var result entsql.Result + if err := driver.Exec( + ctx, + `INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at) +VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) +ON CONFLICT(user_id) DO UPDATE SET + storage_provider = excluded.storage_provider, + storage_key = excluded.storage_key, + url = excluded.url, + content_type = excluded.content_type, + byte_size = excluded.byte_size, + sha256 = excluded.sha256, + updated_at = CURRENT_TIMESTAMP`, + []any{ + userID, + input.StorageProvider, + input.StorageKey, + input.URL, + input.ContentType, + input.ByteSize, + input.SHA256, + }, + &result, + ); err != nil { + return nil, err + } + + return &service.UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil } -func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error { - return nil +func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var result entsql.Result + return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result) } func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { @@ -1636,6 +1761,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int panic("unexpected UpdateConcurrency call") } +func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} + func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx) return count > 0, err diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 70424ec5..0f9f1895 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -537,7 +537,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 4d25a763..b078b804 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -517,7 +517,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 9dcff828..b1ade5c0 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -2,6 +2,7 @@ package handler import ( "context" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -43,8 +44,24 @@ type UpdateProfileRequest struct { type userProfileResponse struct { dto.User - AvatarURL string `json:"avatar_url,omitempty"` - Identities service.UserIdentitySummarySet `json:"identities"` + AvatarURL string `json:"avatar_url,omitempty"` + AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"` + UsernameSource *userProfileSourceContext `json:"username_source,omitempty"` + DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"` + NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"` + ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"` + Identities service.UserIdentitySummarySet `json:"identities"` + AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"` + IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"` + EmailBound bool `json:"email_bound"` + LinuxDoBound bool `json:"linuxdo_bound"` + OIDCBound bool `json:"oidc_bound"` + WeChatBound bool `json:"wechat_bound"` +} + +type userProfileSourceContext struct { + Provider string `json:"provider,omitempty"` + Source string `json:"source,omitempty"` } // GetProfile handles getting user profile @@ -335,9 +352,94 @@ func userProfileResponseFromService(user *service.User, identities service.UserI if base == nil { return userProfileResponse{} } + bindings := userProfileBindingMap(identities) + profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities) return userProfileResponse{ - User: *base, - AvatarURL: user.AvatarURL, - Identities: identities, + User: *base, + AvatarURL: user.AvatarURL, + AvatarSource: avatarSource, + UsernameSource: usernameSource, + DisplayNameSource: usernameSource, + NicknameSource: usernameSource, + ProfileSources: profileSources, + Identities: identities, + AuthBindings: bindings, + IdentityBindings: bindings, + EmailBound: identities.Email.Bound, + LinuxDoBound: identities.LinuxDo.Bound, + OIDCBound: identities.OIDC.Bound, + WeChatBound: identities.WeChat.Bound, + } +} + +func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary { + return map[string]service.UserIdentitySummary{ + "email": identities.Email, + "linuxdo": identities.LinuxDo, + "oidc": identities.OIDC, + "wechat": identities.WeChat, + } +} + +func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) ( + map[string]*userProfileSourceContext, + *userProfileSourceContext, + *userProfileSourceContext, +) { + if user == nil { + return nil, nil, nil + } + + thirdParty := thirdPartyIdentityProviders(identities) + var avatarSource *userProfileSourceContext + if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 { + avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider) + } + + usernameValue := strings.TrimSpace(user.Username) + var usernameSource *userProfileSourceContext + for _, summary := range thirdParty { + if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) { + usernameSource = buildUserProfileSourceContext(summary.Provider) + break + } + } + if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 { + usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider) + } + + profileSources := map[string]*userProfileSourceContext{} + if avatarSource != nil { + profileSources["avatar"] = avatarSource + } + if usernameSource != nil { + profileSources["username"] = usernameSource + profileSources["display_name"] = usernameSource + profileSources["nickname"] = usernameSource + } + if len(profileSources) == 0 { + return nil, avatarSource, usernameSource + } + return profileSources, avatarSource, usernameSource +} + +func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary { + out := make([]service.UserIdentitySummary, 0, 3) + for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} { + if summary.Bound { + out = append(out, summary) + } + } + return out +} + +func buildUserProfileSourceContext(provider string) *userProfileSourceContext { + provider = strings.TrimSpace(provider) + if provider == "" { + return nil + } + return &userProfileSourceContext{ + Provider: provider, + Source: provider, } } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index b71846c1..1216f9c4 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -92,6 +92,12 @@ func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int6 func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } +func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} +func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { return nil } @@ -230,6 +236,79 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start") } +func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 21, + Email: "legacy-profile@example.com", + Username: "linuxdo-handle", + Role: service.RoleUser, + Status: service.StatusActive, + AvatarURL: "https://cdn.example.com/linuxdo.png", + AvatarSource: "remote_url", + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21}) + + handler.GetProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, true, resp.Data["email_bound"]) + require.Equal(t, true, resp.Data["linuxdo_bound"]) + require.Equal(t, false, resp.Data["oidc_bound"]) + require.Equal(t, false, resp.Data["wechat_bound"]) + require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"]) + + authBindings, ok := resp.Data["auth_bindings"].(map[string]any) + require.True(t, ok) + linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, linuxdoBinding["bound"]) + require.Equal(t, "linuxdo", linuxdoBinding["provider"]) + + identityBindings, ok := resp.Data["identity_bindings"].(map[string]any) + require.True(t, ok) + emailBinding, ok := identityBindings["email"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, emailBinding["bound"]) + + avatarSource, ok := resp.Data["avatar_source"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", avatarSource["provider"]) + + profileSources, ok := resp.Data["profile_sources"].(map[string]any) + require.True(t, ok) + usernameSource, ok := profileSources["username"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", usernameSource["provider"]) +} + func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index c52f91bb..c106a3f5 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -65,6 +65,8 @@ type UserRepository interface { List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) + GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) + GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) UpdateBalance(ctx context.Context, id int64, amount float64) error DeductBalance(ctx context.Context, id int64, amount float64) error @@ -159,6 +161,33 @@ type userAuthIdentityReader interface { ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) } +type emailAuthIdentitySynchronizer interface { + EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error + ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error +} + +func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error { + syncer, ok := repo.(emailAuthIdentitySynchronizer) + if !ok { + return nil + } + return syncer.EnsureEmailAuthIdentity(ctx, userID, email) +} + +func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error { + oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail)) + newNormalized := strings.ToLower(strings.TrimSpace(newEmail)) + if oldNormalized == newNormalized { + return nil + } + + syncer, ok := repo.(emailAuthIdentitySynchronizer) + if !ok { + return nil + } + return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail) +} + // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { CurrentPassword string `json:"current_password"` @@ -252,6 +281,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat return nil, fmt.Errorf("get user: %w", err) } oldConcurrency := user.Concurrency + oldEmail := user.Email // 更新字段 if req.Email != nil { @@ -271,24 +301,11 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat } if req.AvatarURL != nil { - avatarValue := strings.TrimSpace(*req.AvatarURL) - switch { - case avatarValue == "": - if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil { - return nil, fmt.Errorf("delete avatar: %w", err) - } - applyUserAvatar(user, nil) - default: - avatarInput, err := normalizeUserAvatarInput(avatarValue) - if err != nil { - return nil, err - } - avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput) - if err != nil { - return nil, fmt.Errorf("upsert avatar: %w", err) - } - applyUserAvatar(user, avatar) + avatar, err := s.SetAvatar(ctx, userID, *req.AvatarURL) + if err != nil { + return nil, err } + applyUserAvatar(user, avatar) } if req.Concurrency != nil { @@ -309,6 +326,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err := s.userRepo.Update(ctx, user); err != nil { return nil, fmt.Errorf("update user: %w", err) } + if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil { + return nil, fmt.Errorf("sync email auth identity: %w", err) + } if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } @@ -316,6 +336,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat return user, nil } +func (s *UserService) SetAvatar(ctx context.Context, userID int64, raw string) (*UserAvatar, error) { + avatarValue := strings.TrimSpace(raw) + if avatarValue == "" { + if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil { + return nil, fmt.Errorf("delete avatar: %w", err) + } + return nil, nil + } + + avatarInput, err := normalizeUserAvatarInput(avatarValue) + if err != nil { + return nil, err + } + + avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput) + if err != nil { + return nil, fmt.Errorf("upsert avatar: %w", err) + } + return avatar, nil +} + func applyUserAvatar(user *User, avatar *UserAvatar) { if user == nil { return diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index 1f6e4cd9..c7b1e503 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -23,6 +23,7 @@ export async function getProfile(): Promise { */ export async function updateProfile(profile: { username?: string + avatar_url?: string | null balance_notify_enabled?: boolean balance_notify_threshold?: number | null balance_notify_extra_emails?: NotifyEmailEntry[] diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue index e82ae229..d6273431 100644 --- a/frontend/src/components/user/profile/ProfileInfoCard.vue +++ b/frontend/src/components/user/profile/ProfileInfoCard.vue @@ -61,6 +61,71 @@
+
+
+
+

+ {{ t('profile.avatar.title') }} +

+

+ {{ t('profile.avatar.description') }} +

+
+ +
+ +
+ + -
- - -
- -
-

- {{ copy.remediationTitle }} -

-

- {{ copy.remediationSubtitle }} -

- -
-
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- - -
-
- - - -
- - - diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue index 93cfdbbe..39c9b377 100644 --- a/frontend/src/views/admin/UsersView.vue +++ b/frontend/src/views/admin/UsersView.vue @@ -712,8 +712,8 @@ const allColumns = computed(() => [ { key: 'usage', label: t('admin.users.columns.usage'), sortable: false }, { key: 'concurrency', label: t('admin.users.columns.concurrency'), sortable: true }, { key: 'status', label: t('admin.users.columns.status'), sortable: true }, - { key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true }, { key: 'last_active_at', label: t('admin.users.columns.lastActive'), sortable: true }, + { key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true }, { key: 'created_at', label: t('admin.users.columns.created'), sortable: true }, { key: 'actions', label: t('admin.users.columns.actions'), sortable: false } ]) diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts deleted file mode 100644 index 20f57fa1..00000000 --- a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts +++ /dev/null @@ -1,303 +0,0 @@ -import { beforeEach, describe, expect, it, vi } from 'vitest' -import { flushPromises, mount } from '@vue/test-utils' -import { defineComponent, h } from 'vue' - -import AuthIdentityMigrationReportsView from '../AuthIdentityMigrationReportsView.vue' - -const { - bindUserAuthIdentity, - getAuthIdentityMigrationReportSummary, - listAuthIdentityMigrationReports, - resolveAuthIdentityMigrationReport, -} = vi.hoisted(() => ({ - bindUserAuthIdentity: vi.fn(), - getAuthIdentityMigrationReportSummary: vi.fn(), - listAuthIdentityMigrationReports: vi.fn(), - resolveAuthIdentityMigrationReport: vi.fn(), -})) - -const { showError, showSuccess } = vi.hoisted(() => ({ - showError: vi.fn(), - showSuccess: vi.fn(), -})) - -vi.mock('@/api/admin', () => ({ - adminAPI: { - users: { - bindUserAuthIdentity, - getAuthIdentityMigrationReportSummary, - listAuthIdentityMigrationReports, - resolveAuthIdentityMigrationReport, - }, - }, -})) - -vi.mock('@/stores/app', () => ({ - useAppStore: () => ({ - showError, - showSuccess, - }), -})) - -vi.mock('vue-i18n', async () => { - const actual = await vi.importActual('vue-i18n') - return { - ...actual, - useI18n: () => ({ - locale: { value: 'en' }, - t: (key: string) => key, - }), - } -}) - -vi.mock('@/utils/format', () => ({ - formatDateTime: (value: string | null | undefined) => value ?? '', -})) - -const sampleReport = { - id: 1, - report_type: 'oidc_synthetic_email_requires_manual_recovery', - report_key: 'legacy@example.invalid', - details: { - user_id: 42, - legacy_email: 'legacy@example.invalid', - provider_key: 'https://issuer.example', - provider_subject: 'subject-123', - }, - created_at: '2026-04-20T01:02:03Z', - resolved_at: null, - resolved_by_user_id: null, - resolution_note: '', -} - -const summaryResponse = { - total: 2, - open_total: 1, - resolved_total: 1, - by_type: { - oidc_synthetic_email_requires_manual_recovery: 2, - }, -} - -const listResponse = { - items: [sampleReport], - total: 1, - page: 1, - page_size: 20, - pages: 1, -} - -const AppLayoutStub = defineComponent({ - setup(_, { slots }) { - return () => h('div', slots.default?.()) - }, -}) - -const TablePageLayoutStub = defineComponent({ - setup(_, { slots }) { - return () => h('div', [ - slots.actions?.(), - slots.filters?.(), - slots.table?.(), - slots.default?.(), - slots.pagination?.(), - ]) - }, -}) - -const DataTableStub = defineComponent({ - props: { - columns: { type: Array, default: () => [] }, - data: { type: Array, default: () => [] }, - loading: { type: Boolean, default: false }, - }, - setup(props, { slots }) { - return () => h('div', { 'data-test': 'data-table' }, [ - props.loading - ? h('div', 'loading') - : (props.data as Array>).map((row) => - h( - 'div', - { key: String(row.id ?? row.report_key) }, - (props.columns as Array<{ key: string }>).map((column) => { - const slot = slots[`cell-${column.key}`] - return h( - 'div', - { key: column.key, [`data-test-cell`]: `${String(row.id)}-${column.key}` }, - slot - ? slot({ row, value: row[column.key] }) - : String(row[column.key] ?? '') - ) - }) - ) - ), - ]) - }, -}) - -const PaginationStub = defineComponent({ - props: { - total: { type: Number, required: true }, - page: { type: Number, required: true }, - pageSize: { type: Number, required: true }, - }, - emits: ['update:page', 'update:pageSize'], - setup(props, { emit }) { - return () => h('div', { 'data-test': 'pagination' }, [ - h('button', { - type: 'button', - 'data-test': 'next-page', - onClick: () => emit('update:page', props.page + 1), - }, 'next'), - h('button', { - type: 'button', - 'data-test': 'page-size-50', - onClick: () => emit('update:pageSize', 50), - }, '50'), - ]) - }, -}) - -describe('AuthIdentityMigrationReportsView', () => { - beforeEach(() => { - getAuthIdentityMigrationReportSummary.mockReset() - listAuthIdentityMigrationReports.mockReset() - resolveAuthIdentityMigrationReport.mockReset() - bindUserAuthIdentity.mockReset() - showError.mockReset() - showSuccess.mockReset() - - getAuthIdentityMigrationReportSummary.mockResolvedValue(summaryResponse) - listAuthIdentityMigrationReports.mockResolvedValue(listResponse) - resolveAuthIdentityMigrationReport.mockResolvedValue({ - ...sampleReport, - resolved_at: '2026-04-20T02:00:00Z', - resolved_by_user_id: 100, - resolution_note: 'resolved by admin', - }) - bindUserAuthIdentity.mockResolvedValue({ - identity_id: 77, - provider_type: 'oidc', - provider_key: 'https://issuer.example', - provider_subject: 'subject-123', - }) - }) - - const mountView = () => - mount(AuthIdentityMigrationReportsView, { - global: { - stubs: { - AppLayout: AppLayoutStub, - TablePageLayout: TablePageLayoutStub, - DataTable: DataTableStub, - Pagination: PaginationStub, - Icon: true, - }, - }, - }) - - it('loads summary and first page of reports on mount', async () => { - const wrapper = mountView() - - await flushPromises() - - expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1) - expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ - page: 1, - pageSize: 20, - reportType: '', - }) - expect(wrapper.get('[data-test="summary-total"]').text()).toContain('2') - expect(wrapper.get('[data-test="summary-open"]').text()).toContain('1') - expect(wrapper.get('[data-test="summary-resolved"]').text()).toContain('1') - expect(wrapper.text()).toContain('legacy@example.invalid') - }) - - it('reloads list when the report type filter changes', async () => { - const wrapper = mountView() - - await flushPromises() - - listAuthIdentityMigrationReports.mockClear() - - await wrapper.get('[data-test="report-type-filter"]').setValue( - 'oidc_synthetic_email_requires_manual_recovery' - ) - await flushPromises() - - expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ - page: 1, - pageSize: 20, - reportType: 'oidc_synthetic_email_requires_manual_recovery', - }) - }) - - it('submits resolve note for the selected report and refreshes data', async () => { - const wrapper = mountView() - - await flushPromises() - - getAuthIdentityMigrationReportSummary.mockClear() - listAuthIdentityMigrationReports.mockClear() - - await wrapper.get('[data-test="select-report-1"]').trigger('click') - await wrapper.get('[data-test="resolution-note"]').setValue('resolved by admin') - await wrapper.get('[data-test="resolve-submit"]').trigger('click') - await flushPromises() - - expect(resolveAuthIdentityMigrationReport).toHaveBeenCalledWith(1, 'resolved by admin') - expect(showSuccess).toHaveBeenCalled() - expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1) - expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ - page: 1, - pageSize: 20, - reportType: '', - }) - }) - - it('pre-fills and submits remediation binding for the selected report', async () => { - const wrapper = mountView() - - await flushPromises() - await wrapper.get('[data-test="select-report-1"]').trigger('click') - await flushPromises() - - expect((wrapper.get('[data-test="remediation-user-id"]').element as HTMLInputElement).value).toBe('42') - expect((wrapper.get('[data-test="remediation-provider-type"]').element as HTMLInputElement).value).toBe('oidc') - expect((wrapper.get('[data-test="remediation-provider-key"]').element as HTMLInputElement).value).toBe( - 'https://issuer.example' - ) - expect((wrapper.get('[data-test="remediation-provider-subject"]').element as HTMLInputElement).value).toBe( - 'subject-123' - ) - - await wrapper.get('[data-test="remediation-submit"]').trigger('click') - await flushPromises() - - expect(bindUserAuthIdentity).toHaveBeenCalledWith(42, { - provider_type: 'oidc', - provider_key: 'https://issuer.example', - provider_subject: 'subject-123', - issuer: undefined, - metadata: {}, - }) - expect(showSuccess).toHaveBeenCalled() - }) - - it('keeps report type filter options available from list data when summary fails', async () => { - getAuthIdentityMigrationReportSummary.mockRejectedValueOnce(new Error('summary failed')) - listAuthIdentityMigrationReports.mockResolvedValueOnce(listResponse) - - const wrapper = mountView() - - await flushPromises() - - const options = wrapper - .get('[data-test="report-type-filter"]') - .findAll('option') - .map((node) => node.element.value) - - expect(showError).toHaveBeenCalled() - expect(options).toContain('oidc_synthetic_email_requires_manual_recovery') - }) -}) diff --git a/frontend/src/views/admin/__tests__/UsersView.spec.ts b/frontend/src/views/admin/__tests__/UsersView.spec.ts index d9076777..532d89f1 100644 --- a/frontend/src/views/admin/__tests__/UsersView.spec.ts +++ b/frontend/src/views/admin/__tests__/UsersView.spec.ts @@ -70,7 +70,6 @@ const createAdminUser = (): AdminUser => ({ created_at: '2026-04-17T00:00:00Z', updated_at: '2026-04-17T00:00:00Z', notes: '', - last_login_at: '2026-04-16T01:00:00Z', last_active_at: '2026-04-16T02:00:00Z', last_used_at: '2026-04-17T02:00:00Z', current_concurrency: 0 @@ -113,7 +112,7 @@ describe('admin UsersView', () => { getBatchUserAttributes.mockResolvedValue({ values: {} }) }) - it('shows active and used activity columns, hides last_login_at, and requests last_used_at sort', async () => { + it('shows active, used, and created activity columns in order and requests last_used_at sort', async () => { const wrapper = mount(UsersView, { global: { stubs: { @@ -145,9 +144,9 @@ describe('admin UsersView', () => { await flushPromises() const columns = wrapper.get('[data-test="columns"]').text() - expect(columns).toContain('last_used_at') - expect(columns).toContain('last_active_at') - expect(columns).not.toContain('last_login_at') + const visibleColumns = columns.split(',') + expect(visibleColumns.slice(-4, -1)).toEqual(['last_active_at', 'last_used_at', 'created_at']) + expect(visibleColumns).not.toContain('last_login_at') await wrapper.get('[data-test="sort-last-used"]').trigger('click') await flushPromises() From ee3f158f4e11dff97e0db0dccecc6885dcfd1b96 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 17:35:12 +0800 Subject: [PATCH 141/326] fix(settings): restore wechat and payment config persistence --- .../internal/handler/admin/setting_handler.go | 99 + backend/internal/handler/auth_wechat_oauth.go | 42 +- .../handler/auth_wechat_oauth_test.go | 100 +- backend/internal/handler/dto/settings.go | 8 + .../handler/setting_handler_public_test.go | 21 +- backend/internal/service/domain_constants.go | 9 + .../service/payment_config_service.go | 9 + .../service/payment_config_service_test.go | 47 +- backend/internal/service/payment_order.go | 16 +- .../service/payment_order_jsapi_test.go | 16 +- .../service/payment_order_result_test.go | 36 +- backend/internal/service/setting_service.go | 168 +- .../service/setting_service_public_test.go | 21 +- .../setting_service_wechat_config_test.go | 77 + backend/internal/service/settings_view.go | 20 + .../__tests__/settings.wechatConnect.spec.ts | 21 + frontend/src/api/admin/settings.ts | 1032 +- frontend/src/views/admin/SettingsView.vue | 9289 ++++++++++------- .../admin/__tests__/SettingsView.spec.ts | 602 +- 19 files changed, 7160 insertions(+), 4473 deletions(-) create mode 100644 backend/internal/service/setting_service_wechat_config_test.go create mode 100644 frontend/src/api/__tests__/settings.wechatConnect.spec.ts diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index f0e91f3a..9bc20771 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -122,6 +122,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { LinuxDoConnectClientID: settings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: settings.WeChatConnectEnabled, + WeChatConnectAppID: settings.WeChatConnectAppID, + WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured, + WeChatConnectMode: settings.WeChatConnectMode, + WeChatConnectScopes: settings.WeChatConnectScopes, + WeChatConnectRedirectURL: settings.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: settings.WeChatConnectFrontendRedirectURL, OIDCConnectEnabled: settings.OIDCConnectEnabled, OIDCConnectProviderName: settings.OIDCConnectProviderName, OIDCConnectClientID: settings.OIDCConnectClientID, @@ -246,6 +253,15 @@ type UpdateSettingsRequest struct { LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + // WeChat Connect OAuth 登录 + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` + WeChatConnectAppID string `json:"wechat_connect_app_id"` + WeChatConnectAppSecret string `json:"wechat_connect_app_secret"` + WeChatConnectMode string `json:"wechat_connect_mode"` + WeChatConnectScopes string `json:"wechat_connect_scopes"` + WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` + WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` + // Generic OIDC OAuth 登录 OIDCConnectEnabled bool `json:"oidc_connect_enabled"` OIDCConnectProviderName string `json:"oidc_connect_provider_name"` @@ -509,6 +525,54 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + if req.WeChatConnectEnabled { + req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID) + req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret) + req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(req.WeChatConnectMode)) + req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes) + req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL) + req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL) + + if req.WeChatConnectAppID == "" { + response.BadRequest(c, "WeChat App ID is required when enabled") + return + } + if req.WeChatConnectAppSecret == "" { + if previousSettings.WeChatConnectAppSecret == "" { + response.BadRequest(c, "WeChat App Secret is required when enabled") + return + } + req.WeChatConnectAppSecret = previousSettings.WeChatConnectAppSecret + } + if req.WeChatConnectMode == "" { + req.WeChatConnectMode = "open" + } + switch req.WeChatConnectMode { + case "open", "mp": + default: + response.BadRequest(c, "WeChat mode must be open or mp") + return + } + if req.WeChatConnectScopes == "" { + req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode) + } + if req.WeChatConnectRedirectURL == "" { + response.BadRequest(c, "WeChat Redirect URL is required when enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil { + response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL") + return + } + if req.WeChatConnectFrontendRedirectURL == "" { + req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback" + } + if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil { + response.BadRequest(c, "WeChat Frontend Redirect URL is invalid") + return + } + } + // Generic OIDC 参数验证 if req.OIDCConnectEnabled { req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName) @@ -857,6 +921,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { LinuxDoConnectClientID: req.LinuxDoConnectClientID, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: req.WeChatConnectEnabled, + WeChatConnectAppID: req.WeChatConnectAppID, + WeChatConnectAppSecret: req.WeChatConnectAppSecret, + WeChatConnectMode: req.WeChatConnectMode, + WeChatConnectScopes: req.WeChatConnectScopes, + WeChatConnectRedirectURL: req.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL, OIDCConnectEnabled: req.OIDCConnectEnabled, OIDCConnectProviderName: req.OIDCConnectProviderName, OIDCConnectClientID: req.OIDCConnectClientID, @@ -1136,6 +1207,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled, + WeChatConnectAppID: updatedSettings.WeChatConnectAppID, + WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured, + WeChatConnectMode: updatedSettings.WeChatConnectMode, + WeChatConnectScopes: updatedSettings.WeChatConnectScopes, + WeChatConnectRedirectURL: updatedSettings.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: updatedSettings.WeChatConnectFrontendRedirectURL, OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled, OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName, OIDCConnectClientID: updatedSettings.OIDCConnectClientID, @@ -1329,6 +1407,27 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { changed = append(changed, "linuxdo_connect_redirect_url") } + if before.WeChatConnectEnabled != after.WeChatConnectEnabled { + changed = append(changed, "wechat_connect_enabled") + } + if before.WeChatConnectAppID != after.WeChatConnectAppID { + changed = append(changed, "wechat_connect_app_id") + } + if req.WeChatConnectAppSecret != "" { + changed = append(changed, "wechat_connect_app_secret") + } + if before.WeChatConnectMode != after.WeChatConnectMode { + changed = append(changed, "wechat_connect_mode") + } + if before.WeChatConnectScopes != after.WeChatConnectScopes { + changed = append(changed, "wechat_connect_scopes") + } + if before.WeChatConnectRedirectURL != after.WeChatConnectRedirectURL { + changed = append(changed, "wechat_connect_redirect_url") + } + if before.WeChatConnectFrontendRedirectURL != after.WeChatConnectFrontendRedirectURL { + changed = append(changed, "wechat_connect_frontend_redirect_url") + } if before.OIDCConnectEnabled != after.OIDCConnectEnabled { changed = append(changed, "oidc_connect_enabled") } diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 5e697fb5..734fb2ef 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "net/url" - "os" "strconv" "strings" "time" @@ -149,7 +148,7 @@ func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) { // WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid, // and stores the result in the unified pending-auth flow. func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { - frontendCallback := wechatOAuthFrontendCallback() + frontendCallback := h.wechatOAuthFrontendCallback(c.Request.Context()) if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) @@ -859,6 +858,10 @@ func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, return wechatOAuthConfig{}, err } + if h == nil || h.settingSvc == nil { + return wechatOAuthConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "wechat oauth settings service not ready") + } + apiBaseURL := "" if h != nil && h.settingSvc != nil { settings, err := h.settingSvc.GetAllSettings(ctx) @@ -867,27 +870,28 @@ func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, } } + effective, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx) + if err != nil { + return wechatOAuthConfig{}, err + } + if effective.Mode != mode { + return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + cfg := wechatOAuthConfig{ mode: mode, - redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"), - frontendCallback: wechatOAuthFrontendCallback(), + appID: strings.TrimSpace(effective.AppID), + appSecret: strings.TrimSpace(effective.AppSecret), + redirectURI: firstNonEmpty(strings.TrimSpace(effective.RedirectURL), resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback")), + frontendCallback: firstNonEmpty(strings.TrimSpace(effective.FrontendRedirectURL), wechatOAuthDefaultFrontendCB), + scope: firstNonEmpty(strings.TrimSpace(effective.Scopes), service.DefaultWeChatConnectScopesForMode(mode)), } switch mode { case "mp": - cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) - cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize" - cfg.scope = "snsapi_userinfo" default: - cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) - cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect" - cfg.scope = "snsapi_login" - } - - if cfg.appID == "" || cfg.appSecret == "" { - return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") } if strings.TrimSpace(cfg.redirectURI) == "" { return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") @@ -896,8 +900,14 @@ func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, return cfg, nil } -func wechatOAuthFrontendCallback() string { - return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB) +func (h *AuthHandler) wechatOAuthFrontendCallback(ctx context.Context) string { + if h != nil && h.settingSvc != nil { + cfg, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx) + if err == nil && strings.TrimSpace(cfg.FrontendRedirectURL) != "" { + return strings.TrimSpace(cfg.FrontendRedirectURL) + } + } + return wechatOAuthDefaultFrontendCB } func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index cd34f52f..b0fee617 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -33,16 +33,22 @@ import ( ) func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - gin.SetMode(gin.TestMode) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-open-app", + service.SettingKeyWeChatConnectAppSecret: "wx-open-secret", + service.SettingKeyWeChatConnectMode: "open", + service.SettingKeyWeChatConnectScopes: "snsapi_login", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + defer client.Close() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) c.Request.Host = "api.example.com" - handler := &AuthHandler{} handler.WeChatOAuthStart(c) require.Equal(t, http.StatusFound, recorder.Code) @@ -60,10 +66,6 @@ func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { } func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -124,10 +126,6 @@ func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { } func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "https://app.example.com/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -151,7 +149,7 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" - handler, client := newWeChatOAuthTestHandler(t, false) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback")) defer client.Close() recorder := httptest.NewRecorder() @@ -177,9 +175,6 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { } func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) { - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret") - originalAccessTokenURL := wechatOAuthAccessTokenURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL @@ -196,7 +191,7 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" - handler, client := newWeChatOAuthTestHandler(t, false) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback")) defer client.Close() handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" @@ -240,7 +235,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test testCases := []struct { name string mode string - appIDEnv string appID string appSecret string openID string @@ -248,7 +242,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test { name: "open", mode: "open", - appIDEnv: "WECHAT_OAUTH_OPEN_APP_ID", appID: "wx-open-app", appSecret: "wx-open-secret", openID: "openid-open-123", @@ -256,7 +249,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test { name: "mp", mode: "mp", - appIDEnv: "WECHAT_OAUTH_MP_APP_ID", appID: "wx-mp-app", appSecret: "wx-mp-secret", openID: "openid-mp-123", @@ -265,15 +257,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - t.Setenv(tc.appIDEnv, tc.appID) - switch tc.mode { - case "open": - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", tc.appSecret) - case "mp": - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", tc.appSecret) - } - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -297,7 +280,7 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" - handler, client := newWeChatOAuthTestHandler(t, false) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback")) defer client.Close() currentUser, err := client.User.Create(). @@ -354,10 +337,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test } func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -436,10 +415,6 @@ func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) } func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -529,10 +504,6 @@ func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { } func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -611,10 +582,6 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes } func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -737,10 +704,6 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing } func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -900,10 +863,6 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi } func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -1010,6 +969,22 @@ func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing } func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil) +} + +func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string { + return map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: appID, + service.SettingKeyWeChatConnectAppSecret: secret, + service.SettingKeyWeChatConnectMode: mode, + service.SettingKeyWeChatConnectScopes: service.DefaultWeChatConnectScopesForMode(mode), + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: frontendRedirect, + } +} + +func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) { t.Helper() db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared") @@ -1036,12 +1011,17 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl UserConcurrency: 1, }, } - settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ - values: map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), - }, - }, cfg) + values := map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + } + for key, value := range wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "/auth/wechat/callback") { + values[key] = value + } + for key, value := range extraSettings { + values[key] = value + } + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: values}, cfg) authSvc := service.NewAuthService( client, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index d67b29a0..4c5edfbf 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -51,6 +51,14 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` + WeChatConnectAppID string `json:"wechat_connect_app_id"` + WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"` + WeChatConnectMode string `json:"wechat_connect_mode"` + WeChatConnectScopes string `json:"wechat_connect_scopes"` + WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` + WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` + OIDCConnectEnabled bool `json:"oidc_connect_enabled"` OIDCConnectProviderName string `json:"oidc_connect_provider_name"` OIDCConnectClientID string `json:"oidc_connect_client_id"` diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go index b50c982c..628d9341 100644 --- a/backend/internal/handler/setting_handler_public_test.go +++ b/backend/internal/handler/setting_handler_public_test.go @@ -84,12 +84,17 @@ func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) { gin.SetMode(gin.TestMode) - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "") - - h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{}, &config.Config{}), "test-version") + h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{ + values: map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-mp-app", + service.SettingKeyWeChatConnectAppSecret: "wx-mp-secret", + service.SettingKeyWeChatConnectMode: "mp", + service.SettingKeyWeChatConnectScopes: "snsapi_base", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}), "test-version") recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -110,6 +115,6 @@ func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t * require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) require.Equal(t, 0, resp.Code) require.True(t, resp.Data.WeChatOAuthEnabled) - require.True(t, resp.Data.WeChatOAuthOpenEnabled) - require.False(t, resp.Data.WeChatOAuthMPEnabled) + require.False(t, resp.Data.WeChatOAuthOpenEnabled) + require.True(t, resp.Data.WeChatOAuthMPEnabled) } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 1dddf77e..4d63cabc 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -111,6 +111,15 @@ const ( SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" + // WeChat Connect OAuth 登录设置 + SettingKeyWeChatConnectEnabled = "wechat_connect_enabled" + SettingKeyWeChatConnectAppID = "wechat_connect_app_id" + SettingKeyWeChatConnectAppSecret = "wechat_connect_app_secret" + SettingKeyWeChatConnectMode = "wechat_connect_mode" + SettingKeyWeChatConnectScopes = "wechat_connect_scopes" + SettingKeyWeChatConnectRedirectURL = "wechat_connect_redirect_url" + SettingKeyWeChatConnectFrontendRedirectURL = "wechat_connect_frontend_redirect_url" + // Generic OIDC OAuth 登录设置 SettingKeyOIDCConnectEnabled = "oidc_connect_enabled" SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name" diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 34462a3a..2d1f3f42 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -93,6 +93,11 @@ type UpdatePaymentConfigRequest struct { CancelRateLimitWindow *int `json:"cancel_rate_limit_window"` CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"` CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"` + + VisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` + VisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"` + VisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"` + VisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"` } // MethodLimits holds per-payment-type limits. @@ -319,6 +324,10 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow), SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit), SettingCancelWindowMode: derefStr(req.CancelRateLimitMode), + SettingPaymentVisibleMethodAlipaySource: derefStr(req.VisibleMethodAlipaySource), + SettingPaymentVisibleMethodWxpaySource: derefStr(req.VisibleMethodWxpaySource), + SettingPaymentVisibleMethodAlipayEnabled: formatBoolOrEmpty(req.VisibleMethodAlipayEnabled), + SettingPaymentVisibleMethodWxpayEnabled: formatBoolOrEmpty(req.VisibleMethodWxpayEnabled), } if req.EnabledTypes != nil { m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",") diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go index 10919058..d58ee234 100644 --- a/backend/internal/service/payment_config_service_test.go +++ b/backend/internal/service/payment_config_service_test.go @@ -366,7 +366,8 @@ func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client { } type paymentConfigSettingRepoStub struct { - values map[string]string + values map[string]string + updates map[string]string } func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) { @@ -383,10 +384,52 @@ func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []str } return out, nil } -func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error { +func (s *paymentConfigSettingRepoStub) SetMultiple(_ context.Context, values map[string]string) error { + s.updates = make(map[string]string, len(values)) + for key, value := range values { + s.updates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } return nil } func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) { return s.values, nil } func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil } + +func TestUpdatePaymentConfig_PersistsVisibleMethodRouting(t *testing.T) { + repo := &paymentConfigSettingRepoStub{values: map[string]string{}} + svc := &PaymentConfigService{settingRepo: repo} + + alipayEnabled := true + wxpayEnabled := false + err := svc.UpdatePaymentConfig(context.Background(), UpdatePaymentConfigRequest{ + VisibleMethodAlipayEnabled: &alipayEnabled, + VisibleMethodAlipaySource: paymentConfigStrPtr(VisibleMethodSourceEasyPayAlipay), + VisibleMethodWxpayEnabled: &wxpayEnabled, + VisibleMethodWxpaySource: paymentConfigStrPtr(VisibleMethodSourceOfficialWechat), + }) + if err != nil { + t.Fatalf("UpdatePaymentConfig returned error: %v", err) + } + + if repo.values[SettingPaymentVisibleMethodAlipayEnabled] != "true" { + t.Fatalf("alipay enabled = %q, want true", repo.values[SettingPaymentVisibleMethodAlipayEnabled]) + } + if repo.values[SettingPaymentVisibleMethodAlipaySource] != VisibleMethodSourceEasyPayAlipay { + t.Fatalf("alipay source = %q, want %q", repo.values[SettingPaymentVisibleMethodAlipaySource], VisibleMethodSourceEasyPayAlipay) + } + if repo.values[SettingPaymentVisibleMethodWxpayEnabled] != "false" { + t.Fatalf("wxpay enabled = %q, want false", repo.values[SettingPaymentVisibleMethodWxpayEnabled]) + } + if repo.values[SettingPaymentVisibleMethodWxpaySource] != VisibleMethodSourceOfficialWechat { + t.Fatalf("wxpay source = %q, want %q", repo.values[SettingPaymentVisibleMethodWxpaySource], VisibleMethodSourceOfficialWechat) + } +} + +func paymentConfigStrPtr(value string) *string { + return &value +} diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index 354f3cd1..01d8c642 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -6,7 +6,6 @@ import ( "log/slog" "math" "net/url" - "os" "strconv" "strings" "time" @@ -512,16 +511,21 @@ func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != "" } -func (s *PaymentService) getWeChatPaymentOAuthCredential(context.Context) (string, string, error) { - appID := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) - appSecret := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) - if appID == "" || appSecret == "" { +func (s *PaymentService) getWeChatPaymentOAuthCredential(ctx context.Context) (string, string, error) { + if s == nil || s.configService == nil || s.configService.settingRepo == nil { return "", "", infraerrors.ServiceUnavailable( "WECHAT_PAYMENT_MP_NOT_CONFIGURED", "wechat in-app payment requires a complete WeChat MP OAuth credential", ) } - return appID, appSecret, nil + cfg, err := (&SettingService{settingRepo: s.configService.settingRepo}).GetWeChatConnectOAuthConfig(ctx) + if err != nil || cfg.Mode != "mp" || strings.TrimSpace(cfg.AppID) == "" || strings.TrimSpace(cfg.AppSecret) == "" { + return "", "", infraerrors.ServiceUnavailable( + "WECHAT_PAYMENT_MP_NOT_CONFIGURED", + "wechat in-app payment requires a complete WeChat MP OAuth credential", + ) + } + return strings.TrimSpace(cfg.AppID), strings.TrimSpace(cfg.AppSecret), nil } func classifyCreatePaymentError(req CreateOrderRequest, providerKey string, err error) error { diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go index 08492432..25f209af 100644 --- a/backend/internal/service/payment_order_jsapi_test.go +++ b/backend/internal/service/payment_order_jsapi_test.go @@ -60,10 +60,17 @@ func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing } configService := &PaymentConfigService{ - entClient: client, + entClient: client, settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{ - SettingPaymentVisibleMethodWxpayEnabled: "true", - SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-mp-app", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", }}, encryptionKey: []byte(jsapiTestEncryptionKey), } @@ -77,9 +84,6 @@ func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing configService: configService, } - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret") - sel, err := svc.selectCreateOrderInstance(ctx, CreateOrderRequest{ PaymentType: payment.TypeWxpay, OpenID: "openid-123", diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go index 0daa8213..16757323 100644 --- a/backend/internal/service/payment_order_result_test.go +++ b/backend/internal/service/payment_order_result_test.go @@ -91,10 +91,15 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) { } func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx123456") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret") - - svc := &PaymentService{} + svc := newWeChatPaymentOAuthTestService(map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ Amount: 12.5, @@ -132,7 +137,7 @@ func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testing.T) { t.Parallel() - svc := &PaymentService{} + svc := newWeChatPaymentOAuthTestService(nil) resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ Amount: 12.5, @@ -155,10 +160,15 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin } func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) { - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx123456") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret") - - svc := &PaymentService{} + svc := newWeChatPaymentOAuthTestService(map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) resp, err := svc.maybeBuildWeChatOAuthRequiredResponseForSelection(context.Background(), CreateOrderRequest{ Amount: 12.5, @@ -175,3 +185,11 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t t.Fatalf("expected nil response, got %+v", resp) } } + +func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService { + return &PaymentService{ + configService: &PaymentConfigService{ + settingRepo: &paymentConfigSettingRepoStub{values: values}, + }, + } +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 8c879d52..373c1aef 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,7 +9,6 @@ import ( "fmt" "log/slog" "net/url" - "os" "sort" "strconv" "strings" @@ -173,8 +172,43 @@ var ( const ( defaultAuthSourceBalance = 0 defaultAuthSourceConcurrency = 5 + defaultWeChatConnectMode = "open" + defaultWeChatConnectScopes = "snsapi_login" + defaultWeChatConnectFrontend = "/auth/wechat/callback" ) +func normalizeWeChatConnectModeSetting(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "mp": + return "mp" + default: + return "open" + } +} + +func defaultWeChatConnectScopeForMode(mode string) string { + if normalizeWeChatConnectModeSetting(mode) == "mp" { + return "snsapi_userinfo" + } + return defaultWeChatConnectScopes +} + +func normalizeWeChatConnectScopeSetting(raw, mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + switch strings.TrimSpace(raw) { + case "snsapi_base": + return "snsapi_base" + case "snsapi_userinfo": + return "snsapi_userinfo" + default: + return defaultWeChatConnectScopeForMode(mode) + } + default: + return defaultWeChatConnectScopes + } +} + // NewSettingService 创建系统设置服务实例 func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService { return &SettingService{ @@ -240,6 +274,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyCustomMenuItems, SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectAppID, + SettingKeyWeChatConnectAppSecret, + SettingKeyWeChatConnectMode, + SettingKeyWeChatConnectScopes, + SettingKeyWeChatConnectRedirectURL, + SettingKeyWeChatConnectFrontendRedirectURL, SettingKeyBackendModeEnabled, SettingPaymentEnabled, SettingKeyOIDCConnectEnabled, @@ -274,9 +315,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings if oidcProviderName == "" { oidcProviderName = "OIDC" } - weChatOpenEnabled := isWeChatOAuthOpenConfigured() - weChatMPEnabled := isWeChatOAuthMPConfigured() - weChatEnabled := weChatOpenEnabled || weChatMPEnabled + weChatEnabled, weChatOpenEnabled, weChatMPEnabled := s.weChatOAuthCapabilitiesFromSettings(settings) // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" @@ -431,6 +470,56 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any }, nil } +func DefaultWeChatConnectScopesForMode(mode string) string { + return defaultWeChatConnectScopeForMode(mode) +} + +func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) { + cfg := WeChatConnectOAuthConfig{ + Enabled: settings[SettingKeyWeChatConnectEnabled] == "true", + AppID: strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]), + AppSecret: strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]), + Mode: normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode]), + Scopes: normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], settings[SettingKeyWeChatConnectMode]), + RedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]), + FrontendRedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]), + } + if cfg.FrontendRedirectURL == "" { + cfg.FrontendRedirectURL = defaultWeChatConnectFrontend + } + + if !cfg.Enabled { + return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + if cfg.AppID == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth app id not configured") + } + if cfg.AppSecret == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth app secret not configured") + } + if cfg.RedirectURL == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") + } + if cfg.FrontendRedirectURL == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url not configured") + } + if err := config.ValidateAbsoluteHTTPURL(cfg.RedirectURL); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid") + } + if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid") + } + return cfg, nil +} + +func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool) { + cfg, err := s.parseWeChatConnectOAuthConfig(settings) + if err != nil { + return false, false, false + } + return true, cfg.Mode == "open", cfg.Mode == "mp" +} + // filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON // array string, returning only items with visibility != "admin". func filterUserVisibleMenuItems(raw string) json.RawMessage { @@ -467,20 +556,6 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage { return result } -func isWeChatOAuthConfigured() bool { - return isWeChatOAuthOpenConfigured() || isWeChatOAuthMPConfigured() -} - -func isWeChatOAuthOpenConfigured() bool { - return strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" && - strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != "" -} - -func isWeChatOAuthMPConfigured() bool { - return strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" && - strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != "" -} - // safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]". func safeRawJSONArray(raw string) json.RawMessage { raw = strings.TrimSpace(raw) @@ -625,6 +700,15 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting } settings.PaymentVisibleMethodAlipaySource = alipaySource settings.PaymentVisibleMethodWxpaySource = wxpaySource + settings.WeChatConnectAppID = strings.TrimSpace(settings.WeChatConnectAppID) + settings.WeChatConnectAppSecret = strings.TrimSpace(settings.WeChatConnectAppSecret) + settings.WeChatConnectMode = normalizeWeChatConnectModeSetting(settings.WeChatConnectMode) + settings.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings.WeChatConnectScopes, settings.WeChatConnectMode) + settings.WeChatConnectRedirectURL = strings.TrimSpace(settings.WeChatConnectRedirectURL) + settings.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings.WeChatConnectFrontendRedirectURL) + if settings.WeChatConnectFrontendRedirectURL == "" { + settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend + } updates := make(map[string]string) @@ -694,6 +778,17 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret } + // WeChat Connect OAuth 登录 + updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled) + updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID + updates[SettingKeyWeChatConnectMode] = settings.WeChatConnectMode + updates[SettingKeyWeChatConnectScopes] = settings.WeChatConnectScopes + updates[SettingKeyWeChatConnectRedirectURL] = settings.WeChatConnectRedirectURL + updates[SettingKeyWeChatConnectFrontendRedirectURL] = settings.WeChatConnectFrontendRedirectURL + if settings.WeChatConnectAppSecret != "" { + updates[SettingKeyWeChatConnectAppSecret] = settings.WeChatConnectAppSecret + } + // OEM设置 updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteLogo] = settings.SiteLogo @@ -1200,6 +1295,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyTablePageSizeOptions: "[10,20,50,100]", SettingKeyCustomMenuItems: "[]", SettingKeyCustomEndpoints: "[]", + SettingKeyWeChatConnectEnabled: "false", + SettingKeyWeChatConnectMode: "open", + SettingKeyWeChatConnectScopes: "snsapi_login", + SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend, SettingKeyOIDCConnectEnabled: "false", SettingKeyOIDCConnectProviderName: "OIDC", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), @@ -1491,6 +1590,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != "" + // WeChat Connect 设置:完全以 DB 系统设置为准。 + result.WeChatConnectEnabled = settings[SettingKeyWeChatConnectEnabled] == "true" + result.WeChatConnectAppID = strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]) + result.WeChatConnectAppSecret = strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]) + result.WeChatConnectAppSecretConfigured = result.WeChatConnectAppSecret != "" + result.WeChatConnectMode = normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode]) + result.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], settings[SettingKeyWeChatConnectMode]) + result.WeChatConnectRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]) + result.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]) + if result.WeChatConnectFrontendRedirectURL == "" { + result.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend + } + // Model fallback settings result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") @@ -1972,6 +2084,26 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return effective, nil } +// GetWeChatConnectOAuthConfig 返回用于登录的最终生效 WeChat Connect 配置。 +// +// WeChat Connect 已回归 DB 系统设置模型,不再回退到 config/env。 +func (s *SettingService) GetWeChatConnectOAuthConfig(ctx context.Context) (WeChatConnectOAuthConfig, error) { + keys := []string{ + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectAppID, + SettingKeyWeChatConnectAppSecret, + SettingKeyWeChatConnectMode, + SettingKeyWeChatConnectScopes, + SettingKeyWeChatConnectRedirectURL, + SettingKeyWeChatConnectFrontendRedirectURL, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return WeChatConnectOAuthConfig{}, fmt.Errorf("get wechat connect settings: %w", err) + } + return s.parseWeChatConnectOAuthConfig(settings) +} + // GetOverloadCooldownSettings 获取529过载冷却配置 func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) { value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings) diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 4cfa9f0c..ffc069dc 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -92,16 +92,21 @@ func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t } func TestSettingService_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "") - - svc := NewSettingService(&settingPublicRepoStub{}, &config.Config{}) + svc := NewSettingService(&settingPublicRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-mp-app", + SettingKeyWeChatConnectAppSecret: "wx-mp-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}) settings, err := svc.GetPublicSettings(context.Background()) require.NoError(t, err) require.True(t, settings.WeChatOAuthEnabled) - require.True(t, settings.WeChatOAuthOpenEnabled) - require.False(t, settings.WeChatOAuthMPEnabled) + require.False(t, settings.WeChatOAuthOpenEnabled) + require.True(t, settings.WeChatOAuthMPEnabled) } diff --git a/backend/internal/service/setting_service_wechat_config_test.go b/backend/internal/service/setting_service_wechat_config_test.go new file mode 100644 index 00000000..2cb312cc --- /dev/null +++ b/backend/internal/service/setting_service_wechat_config_test.go @@ -0,0 +1,77 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type settingWeChatRepoStub struct { + values map[string]string +} + +func (s *settingWeChatRepoStub) Get(context.Context, string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingWeChatRepoStub) GetValue(_ context.Context, key string) (string, error) { + if value, ok := s.values[key]; ok { + return value, nil + } + return "", ErrSettingNotFound +} + +func (s *settingWeChatRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *settingWeChatRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingWeChatRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingWeChatRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingWeChatRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetWeChatConnectOAuthConfig_UsesDatabaseOverrides(t *testing.T) { + repo := &settingWeChatRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-db-app", + SettingKeyWeChatConnectAppSecret: "wx-db-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + got, err := svc.GetWeChatConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.Enabled) + require.Equal(t, "wx-db-app", got.AppID) + require.Equal(t, "wx-db-secret", got.AppSecret) + require.Equal(t, "mp", got.Mode) + require.Equal(t, "snsapi_base", got.Scopes) + require.Equal(t, "https://api.example.com/api/v1/auth/oauth/wechat/callback", got.RedirectURL) + require.Equal(t, "/auth/wechat/callback", got.FrontendRedirectURL) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 1b859dbd..41229bfb 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -31,6 +31,16 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool LinuxDoConnectRedirectURL string + // WeChat Connect OAuth 登录 + WeChatConnectEnabled bool + WeChatConnectAppID string + WeChatConnectAppSecret string + WeChatConnectAppSecretConfigured bool + WeChatConnectMode string + WeChatConnectScopes string + WeChatConnectRedirectURL string + WeChatConnectFrontendRedirectURL string + // Generic OIDC OAuth 登录 OIDCConnectEnabled bool OIDCConnectProviderName string @@ -177,6 +187,16 @@ type PublicSettings struct { BalanceLowNotifyRechargeURL string } +type WeChatConnectOAuthConfig struct { + Enabled bool + AppID string + AppSecret string + Mode string + Scopes string + RedirectURL string + FrontendRedirectURL string +} + // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) type StreamTimeoutSettings struct { // Enabled 是否启用流超时处理 diff --git a/frontend/src/api/__tests__/settings.wechatConnect.spec.ts b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts new file mode 100644 index 00000000..eccb7214 --- /dev/null +++ b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts @@ -0,0 +1,21 @@ +import { describe, expect, it } from "vitest"; + +import { + defaultWeChatConnectScopesForMode, + normalizeWeChatConnectMode, +} from "@/api/admin/settings"; + +describe("admin settings wechat connect helpers", () => { + it("normalizes legacy or noisy mode values to the backend contract", () => { + expect(normalizeWeChatConnectMode("OPEN")).toBe("open"); + expect(normalizeWeChatConnectMode(" open_platform ")).toBe("open"); + expect(normalizeWeChatConnectMode("mp")).toBe("mp"); + expect(normalizeWeChatConnectMode("official_account")).toBe("mp"); + expect(normalizeWeChatConnectMode("unknown")).toBe("open"); + }); + + it("maps each mode to the backend default scopes", () => { + expect(defaultWeChatConnectScopesForMode("open")).toBe("snsapi_login"); + expect(defaultWeChatConnectScopesForMode("mp")).toBe("snsapi_userinfo"); + }); +}); diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 235bda7b..ed78a1bc 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -3,155 +3,240 @@ * Handles system settings management for administrators */ -import { apiClient } from '../client' -import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from '@/types' +import { apiClient } from "../client"; +import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from "@/types"; export interface DefaultSubscriptionSetting { - group_id: number - validity_days: number + group_id: number; + validity_days: number; } -export type AuthSourceType = 'email' | 'linuxdo' | 'oidc' | 'wechat' +export type AuthSourceType = "email" | "linuxdo" | "oidc" | "wechat"; export interface AuthSourceDefaultsValue { - balance: number - concurrency: number - subscriptions: DefaultSubscriptionSetting[] - grant_on_signup: boolean - grant_on_first_bind: boolean + balance: number; + concurrency: number; + subscriptions: DefaultSubscriptionSetting[]; + grant_on_signup: boolean; + grant_on_first_bind: boolean; } -export type AuthSourceDefaultsState = Record -export type PaymentVisibleMethod = 'alipay' | 'wxpay' +export type AuthSourceDefaultsState = Record< + AuthSourceType, + AuthSourceDefaultsValue +>; +export type PaymentVisibleMethod = "alipay" | "wxpay"; export type PaymentVisibleMethodSource = - | '' - | 'official_alipay' - | 'easypay_alipay' - | 'official_wxpay' - | 'easypay_wxpay' + | "" + | "official_alipay" + | "easypay_alipay" + | "official_wxpay" + | "easypay_wxpay"; +export type WeChatConnectMode = "open" | "mp"; export interface PaymentVisibleMethodSourceOption { - value: PaymentVisibleMethodSource - labelZh: string - labelEn: string + value: PaymentVisibleMethodSource; + labelZh: string; + labelEn: string; } -const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat'] -const AUTH_SOURCE_DEFAULT_BALANCE = 0 -const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5 +export interface WeChatConnectModeOption { + value: WeChatConnectMode; + labelZh: string; + labelEn: string; +} + +const AUTH_SOURCE_TYPES: AuthSourceType[] = [ + "email", + "linuxdo", + "oidc", + "wechat", +]; +const AUTH_SOURCE_DEFAULT_BALANCE = 0; +const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5; const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record< PaymentVisibleMethod, PaymentVisibleMethodSourceOption[] > = { alipay: [ - { value: '', labelZh: '未配置', labelEn: 'Not configured' }, - { value: 'official_alipay', labelZh: '支付宝官方', labelEn: 'Official Alipay' }, - { value: 'easypay_alipay', labelZh: '易支付支付宝', labelEn: 'EasyPay Alipay' }, + { value: "", labelZh: "未配置", labelEn: "Not configured" }, + { + value: "official_alipay", + labelZh: "支付宝官方", + labelEn: "Official Alipay", + }, + { + value: "easypay_alipay", + labelZh: "易支付支付宝", + labelEn: "EasyPay Alipay", + }, ], wxpay: [ - { value: '', labelZh: '未配置', labelEn: 'Not configured' }, - { value: 'official_wxpay', labelZh: '微信官方', labelEn: 'Official WeChat Pay' }, - { value: 'easypay_wxpay', labelZh: '易支付微信', labelEn: 'EasyPay WeChat Pay' }, + { value: "", labelZh: "未配置", labelEn: "Not configured" }, + { + value: "official_wxpay", + labelZh: "微信官方", + labelEn: "Official WeChat Pay", + }, + { + value: "easypay_wxpay", + labelZh: "易支付微信", + labelEn: "EasyPay WeChat Pay", + }, ], -} +}; const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record< PaymentVisibleMethod, Record > = { alipay: { - official_alipay: 'official_alipay', - alipay: 'official_alipay', - alipay_direct: 'official_alipay', - official: 'official_alipay', - easypay_alipay: 'easypay_alipay', - easypay: 'easypay_alipay', + official_alipay: "official_alipay", + alipay: "official_alipay", + alipay_direct: "official_alipay", + official: "official_alipay", + easypay_alipay: "easypay_alipay", + easypay: "easypay_alipay", }, wxpay: { - official_wxpay: 'official_wxpay', - wxpay: 'official_wxpay', - wxpay_direct: 'official_wxpay', - wechat: 'official_wxpay', - official: 'official_wxpay', - easypay_wxpay: 'easypay_wxpay', - easypay: 'easypay_wxpay', + official_wxpay: "official_wxpay", + wxpay: "official_wxpay", + wxpay_direct: "official_wxpay", + wechat: "official_wxpay", + official: "official_wxpay", + easypay_wxpay: "easypay_wxpay", + easypay: "easypay_wxpay", }, -} +}; +const WECHAT_CONNECT_MODE_OPTIONS: WeChatConnectModeOption[] = [ + { value: "open", labelZh: "微信开放平台", labelEn: "WeChat Open Platform" }, + { + value: "mp", + labelZh: "微信公众号 / 小程序", + labelEn: "WeChat Official Account / Mini Program", + }, +]; +const WECHAT_CONNECT_MODE_ALIASES: Record = { + open: "open", + open_platform: "open", + official: "open", + wx_open: "open", + mp: "mp", + official_account: "mp", + wechat_mp: "mp", + mini_program: "mp", +}; export function normalizeDefaultSubscriptionSettings( - subscriptions: DefaultSubscriptionSetting[] | null | undefined + subscriptions: DefaultSubscriptionSetting[] | null | undefined, ): DefaultSubscriptionSetting[] { - if (!Array.isArray(subscriptions)) return [] + if (!Array.isArray(subscriptions)) return []; return subscriptions .filter((item) => item.group_id > 0 && item.validity_days > 0) .map((item) => ({ group_id: Math.floor(item.group_id), - validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days))) - })) + validity_days: Math.min( + 36500, + Math.max(1, Math.floor(item.validity_days)), + ), + })); } export function buildAuthSourceDefaultsState( - settings: Partial + settings: Partial, ): AuthSourceDefaultsState { - const raw = settings as Record + const raw = settings as Record; return AUTH_SOURCE_TYPES.reduce((acc, source) => { - const subscriptions = raw[`auth_source_default_${source}_subscriptions`] + const subscriptions = raw[`auth_source_default_${source}_subscriptions`]; acc[source] = { - balance: Number(raw[`auth_source_default_${source}_balance`] ?? AUTH_SOURCE_DEFAULT_BALANCE), + balance: Number( + raw[`auth_source_default_${source}_balance`] ?? + AUTH_SOURCE_DEFAULT_BALANCE, + ), concurrency: Math.max( 1, - Number(raw[`auth_source_default_${source}_concurrency`] ?? AUTH_SOURCE_DEFAULT_CONCURRENCY) + Number( + raw[`auth_source_default_${source}_concurrency`] ?? + AUTH_SOURCE_DEFAULT_CONCURRENCY, + ), ), subscriptions: normalizeDefaultSubscriptionSettings( - Array.isArray(subscriptions) ? (subscriptions as DefaultSubscriptionSetting[]) : [] + Array.isArray(subscriptions) + ? (subscriptions as DefaultSubscriptionSetting[]) + : [], ), - grant_on_signup: raw[`auth_source_default_${source}_grant_on_signup`] !== false, - grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true, - } - return acc - }, {} as AuthSourceDefaultsState) + grant_on_signup: + raw[`auth_source_default_${source}_grant_on_signup`] !== false, + grant_on_first_bind: + raw[`auth_source_default_${source}_grant_on_first_bind`] === true, + }; + return acc; + }, {} as AuthSourceDefaultsState); } export function appendAuthSourceDefaultsToUpdateRequest( payload: UpdateSettingsRequest, - authSourceDefaults: AuthSourceDefaultsState + authSourceDefaults: AuthSourceDefaultsState, ): UpdateSettingsRequest { - const target = payload as Record + const target = payload as Record; for (const source of AUTH_SOURCE_TYPES) { - const current = authSourceDefaults[source] - target[`auth_source_default_${source}_balance`] = Number(current.balance) || 0 + const current = authSourceDefaults[source]; + target[`auth_source_default_${source}_balance`] = + Number(current.balance) || 0; target[`auth_source_default_${source}_concurrency`] = Math.max( 1, - Math.floor(Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY) - ) - target[`auth_source_default_${source}_subscriptions`] = normalizeDefaultSubscriptionSettings( - current.subscriptions - ) - target[`auth_source_default_${source}_grant_on_signup`] = current.grant_on_signup - target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind + Math.floor( + Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY, + ), + ); + target[`auth_source_default_${source}_subscriptions`] = + normalizeDefaultSubscriptionSettings(current.subscriptions); + target[`auth_source_default_${source}_grant_on_signup`] = + current.grant_on_signup; + target[`auth_source_default_${source}_grant_on_first_bind`] = + current.grant_on_first_bind; } - return payload + return payload; } export function getPaymentVisibleMethodSourceOptions( - method: PaymentVisibleMethod + method: PaymentVisibleMethod, ): PaymentVisibleMethodSourceOption[] { - return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method] + return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method]; } export function normalizePaymentVisibleMethodSource( method: PaymentVisibleMethod, - source: unknown + source: unknown, ): PaymentVisibleMethodSource { - if (typeof source !== 'string') return '' + if (typeof source !== "string") return ""; - const normalized = source.trim().toLowerCase() - if (!normalized) return '' + const normalized = source.trim().toLowerCase(); + if (!normalized) return ""; - return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? '' + return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? ""; +} + +export function getWeChatConnectModeOptions(): WeChatConnectModeOption[] { + return WECHAT_CONNECT_MODE_OPTIONS; +} + +export function normalizeWeChatConnectMode(source: unknown): WeChatConnectMode { + if (typeof source !== "string") return "open"; + + const normalized = source.trim().toLowerCase(); + if (!normalized) return "open"; + + return WECHAT_CONNECT_MODE_ALIASES[normalized] ?? "open"; +} + +export function defaultWeChatConnectScopesForMode(mode: unknown): string { + return normalizeWeChatConnectMode(mode) === "mp" + ? "snsapi_userinfo" + : "snsapi_login"; } /** @@ -159,293 +244,309 @@ export function normalizePaymentVisibleMethodSource( */ export interface SystemSettings { // Registration settings - registration_enabled: boolean - email_verify_enabled: boolean - registration_email_suffix_whitelist: string[] - promo_code_enabled: boolean - password_reset_enabled: boolean - frontend_url: string - invitation_code_enabled: boolean - totp_enabled: boolean // TOTP 双因素认证 - totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置 + registration_enabled: boolean; + email_verify_enabled: boolean; + registration_email_suffix_whitelist: string[]; + promo_code_enabled: boolean; + password_reset_enabled: boolean; + frontend_url: string; + invitation_code_enabled: boolean; + totp_enabled: boolean; // TOTP 双因素认证 + totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置 // Default settings - default_balance: number - default_concurrency: number - default_subscriptions: DefaultSubscriptionSetting[] - auth_source_default_email_balance?: number - auth_source_default_email_concurrency?: number - auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_email_grant_on_signup?: boolean - auth_source_default_email_grant_on_first_bind?: boolean - auth_source_default_linuxdo_balance?: number - auth_source_default_linuxdo_concurrency?: number - auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_linuxdo_grant_on_signup?: boolean - auth_source_default_linuxdo_grant_on_first_bind?: boolean - auth_source_default_oidc_balance?: number - auth_source_default_oidc_concurrency?: number - auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_oidc_grant_on_signup?: boolean - auth_source_default_oidc_grant_on_first_bind?: boolean - auth_source_default_wechat_balance?: number - auth_source_default_wechat_concurrency?: number - auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_wechat_grant_on_signup?: boolean - auth_source_default_wechat_grant_on_first_bind?: boolean - force_email_on_third_party_signup?: boolean + default_balance: number; + default_concurrency: number; + default_subscriptions: DefaultSubscriptionSetting[]; + auth_source_default_email_balance?: number; + auth_source_default_email_concurrency?: number; + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_grant_on_signup?: boolean; + auth_source_default_email_grant_on_first_bind?: boolean; + auth_source_default_linuxdo_balance?: number; + auth_source_default_linuxdo_concurrency?: number; + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_linuxdo_grant_on_signup?: boolean; + auth_source_default_linuxdo_grant_on_first_bind?: boolean; + auth_source_default_oidc_balance?: number; + auth_source_default_oidc_concurrency?: number; + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_oidc_grant_on_signup?: boolean; + auth_source_default_oidc_grant_on_first_bind?: boolean; + auth_source_default_wechat_balance?: number; + auth_source_default_wechat_concurrency?: number; + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_wechat_grant_on_signup?: boolean; + auth_source_default_wechat_grant_on_first_bind?: boolean; + force_email_on_third_party_signup?: boolean; // OEM settings - site_name: string - site_logo: string - site_subtitle: string - api_base_url: string - contact_info: string - doc_url: string - home_content: string - hide_ccs_import_button: boolean - table_default_page_size: number - table_page_size_options: number[] - backend_mode_enabled: boolean - custom_menu_items: CustomMenuItem[] - custom_endpoints: CustomEndpoint[] + site_name: string; + site_logo: string; + site_subtitle: string; + api_base_url: string; + contact_info: string; + doc_url: string; + home_content: string; + hide_ccs_import_button: boolean; + table_default_page_size: number; + table_page_size_options: number[]; + backend_mode_enabled: boolean; + custom_menu_items: CustomMenuItem[]; + custom_endpoints: CustomEndpoint[]; // SMTP settings - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password_configured: boolean - smtp_from_email: string - smtp_from_name: string - smtp_use_tls: boolean + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password_configured: boolean; + smtp_from_email: string; + smtp_from_name: string; + smtp_use_tls: boolean; // Cloudflare Turnstile settings - turnstile_enabled: boolean - turnstile_site_key: string - turnstile_secret_key_configured: boolean + turnstile_enabled: boolean; + turnstile_site_key: string; + turnstile_secret_key_configured: boolean; // LinuxDo Connect OAuth settings - linuxdo_connect_enabled: boolean - linuxdo_connect_client_id: string - linuxdo_connect_client_secret_configured: boolean - linuxdo_connect_redirect_url: string + linuxdo_connect_enabled: boolean; + linuxdo_connect_client_id: string; + linuxdo_connect_client_secret_configured: boolean; + linuxdo_connect_redirect_url: string; + + // WeChat Connect OAuth settings + wechat_connect_enabled: boolean; + wechat_connect_app_id: string; + wechat_connect_app_secret_configured: boolean; + wechat_connect_mode: string; + wechat_connect_scopes: string; + wechat_connect_redirect_url: string; + wechat_connect_frontend_redirect_url: string; // Generic OIDC OAuth settings - oidc_connect_enabled: boolean - oidc_connect_provider_name: string - oidc_connect_client_id: string - oidc_connect_client_secret_configured: boolean - oidc_connect_issuer_url: string - oidc_connect_discovery_url: string - oidc_connect_authorize_url: string - oidc_connect_token_url: string - oidc_connect_userinfo_url: string - oidc_connect_jwks_url: string - oidc_connect_scopes: string - oidc_connect_redirect_url: string - oidc_connect_frontend_redirect_url: string - oidc_connect_token_auth_method: string - oidc_connect_use_pkce: boolean - oidc_connect_validate_id_token: boolean - oidc_connect_allowed_signing_algs: string - oidc_connect_clock_skew_seconds: number - oidc_connect_require_email_verified: boolean - oidc_connect_userinfo_email_path: string - oidc_connect_userinfo_id_path: string - oidc_connect_userinfo_username_path: string + oidc_connect_enabled: boolean; + oidc_connect_provider_name: string; + oidc_connect_client_id: string; + oidc_connect_client_secret_configured: boolean; + oidc_connect_issuer_url: string; + oidc_connect_discovery_url: string; + oidc_connect_authorize_url: string; + oidc_connect_token_url: string; + oidc_connect_userinfo_url: string; + oidc_connect_jwks_url: string; + oidc_connect_scopes: string; + oidc_connect_redirect_url: string; + oidc_connect_frontend_redirect_url: string; + oidc_connect_token_auth_method: string; + oidc_connect_use_pkce: boolean; + oidc_connect_validate_id_token: boolean; + oidc_connect_allowed_signing_algs: string; + oidc_connect_clock_skew_seconds: number; + oidc_connect_require_email_verified: boolean; + oidc_connect_userinfo_email_path: string; + oidc_connect_userinfo_id_path: string; + oidc_connect_userinfo_username_path: string; // Model fallback configuration - enable_model_fallback: boolean - fallback_model_anthropic: string - fallback_model_openai: string - fallback_model_gemini: string - fallback_model_antigravity: string + enable_model_fallback: boolean; + fallback_model_anthropic: string; + fallback_model_openai: string; + fallback_model_gemini: string; + fallback_model_antigravity: string; // Identity patch configuration (Claude -> Gemini) - enable_identity_patch: boolean - identity_patch_prompt: string + enable_identity_patch: boolean; + identity_patch_prompt: string; // Ops Monitoring (vNext) - ops_monitoring_enabled: boolean - ops_realtime_monitoring_enabled: boolean - ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string - ops_metrics_interval_seconds: number + ops_monitoring_enabled: boolean; + ops_realtime_monitoring_enabled: boolean; + ops_query_mode_default: "auto" | "raw" | "preagg" | string; + ops_metrics_interval_seconds: number; // Claude Code version check - min_claude_code_version: string - max_claude_code_version: string + min_claude_code_version: string; + max_claude_code_version: string; // 分组隔离 - allow_ungrouped_key_scheduling: boolean + allow_ungrouped_key_scheduling: boolean; // Gateway forwarding behavior - enable_fingerprint_unification: boolean - enable_metadata_passthrough: boolean - enable_cch_signing: boolean - web_search_emulation_enabled?: boolean + enable_fingerprint_unification: boolean; + enable_metadata_passthrough: boolean; + enable_cch_signing: boolean; + web_search_emulation_enabled?: boolean; // Payment configuration - payment_enabled: boolean - payment_min_amount: number - payment_max_amount: number - payment_daily_limit: number - payment_order_timeout_minutes: number - payment_max_pending_orders: number - payment_enabled_types: string[] - payment_balance_disabled: boolean - payment_balance_recharge_multiplier: number - payment_recharge_fee_rate: number - payment_load_balance_strategy: string - payment_product_name_prefix: string - payment_product_name_suffix: string - payment_help_image_url: string - payment_help_text: string - payment_cancel_rate_limit_enabled: boolean - payment_cancel_rate_limit_max: number - payment_cancel_rate_limit_window: number - payment_cancel_rate_limit_unit: string - payment_cancel_rate_limit_window_mode: string - payment_visible_method_alipay_source?: string - payment_visible_method_wxpay_source?: string - payment_visible_method_alipay_enabled?: boolean - payment_visible_method_wxpay_enabled?: boolean - openai_advanced_scheduler_enabled?: boolean + payment_enabled: boolean; + payment_min_amount: number; + payment_max_amount: number; + payment_daily_limit: number; + payment_order_timeout_minutes: number; + payment_max_pending_orders: number; + payment_enabled_types: string[]; + payment_balance_disabled: boolean; + payment_balance_recharge_multiplier: number; + payment_recharge_fee_rate: number; + payment_load_balance_strategy: string; + payment_product_name_prefix: string; + payment_product_name_suffix: string; + payment_help_image_url: string; + payment_help_text: string; + payment_cancel_rate_limit_enabled: boolean; + payment_cancel_rate_limit_max: number; + payment_cancel_rate_limit_window: number; + payment_cancel_rate_limit_unit: string; + payment_cancel_rate_limit_window_mode: string; + payment_visible_method_alipay_source?: string; + payment_visible_method_wxpay_source?: string; + payment_visible_method_alipay_enabled?: boolean; + payment_visible_method_wxpay_enabled?: boolean; + openai_advanced_scheduler_enabled?: boolean; // Balance & quota notification - balance_low_notify_enabled: boolean - balance_low_notify_threshold: number - balance_low_notify_recharge_url: string - account_quota_notify_enabled: boolean - account_quota_notify_emails: NotifyEmailEntry[] + balance_low_notify_enabled: boolean; + balance_low_notify_threshold: number; + balance_low_notify_recharge_url: string; + account_quota_notify_enabled: boolean; + account_quota_notify_emails: NotifyEmailEntry[]; } export interface UpdateSettingsRequest { - registration_enabled?: boolean - email_verify_enabled?: boolean - registration_email_suffix_whitelist?: string[] - promo_code_enabled?: boolean - password_reset_enabled?: boolean - frontend_url?: string - invitation_code_enabled?: boolean - totp_enabled?: boolean // TOTP 双因素认证 - default_balance?: number - default_concurrency?: number - default_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_email_balance?: number - auth_source_default_email_concurrency?: number - auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_email_grant_on_signup?: boolean - auth_source_default_email_grant_on_first_bind?: boolean - auth_source_default_linuxdo_balance?: number - auth_source_default_linuxdo_concurrency?: number - auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_linuxdo_grant_on_signup?: boolean - auth_source_default_linuxdo_grant_on_first_bind?: boolean - auth_source_default_oidc_balance?: number - auth_source_default_oidc_concurrency?: number - auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_oidc_grant_on_signup?: boolean - auth_source_default_oidc_grant_on_first_bind?: boolean - auth_source_default_wechat_balance?: number - auth_source_default_wechat_concurrency?: number - auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_wechat_grant_on_signup?: boolean - auth_source_default_wechat_grant_on_first_bind?: boolean - force_email_on_third_party_signup?: boolean - site_name?: string - site_logo?: string - site_subtitle?: string - api_base_url?: string - contact_info?: string - doc_url?: string - home_content?: string - hide_ccs_import_button?: boolean - table_default_page_size?: number - table_page_size_options?: number[] - backend_mode_enabled?: boolean - custom_menu_items?: CustomMenuItem[] - custom_endpoints?: CustomEndpoint[] - smtp_host?: string - smtp_port?: number - smtp_username?: string - smtp_password?: string - smtp_from_email?: string - smtp_from_name?: string - smtp_use_tls?: boolean - turnstile_enabled?: boolean - turnstile_site_key?: string - turnstile_secret_key?: string - linuxdo_connect_enabled?: boolean - linuxdo_connect_client_id?: string - linuxdo_connect_client_secret?: string - linuxdo_connect_redirect_url?: string - oidc_connect_enabled?: boolean - oidc_connect_provider_name?: string - oidc_connect_client_id?: string - oidc_connect_client_secret?: string - oidc_connect_issuer_url?: string - oidc_connect_discovery_url?: string - oidc_connect_authorize_url?: string - oidc_connect_token_url?: string - oidc_connect_userinfo_url?: string - oidc_connect_jwks_url?: string - oidc_connect_scopes?: string - oidc_connect_redirect_url?: string - oidc_connect_frontend_redirect_url?: string - oidc_connect_token_auth_method?: string - oidc_connect_use_pkce?: boolean - oidc_connect_validate_id_token?: boolean - oidc_connect_allowed_signing_algs?: string - oidc_connect_clock_skew_seconds?: number - oidc_connect_require_email_verified?: boolean - oidc_connect_userinfo_email_path?: string - oidc_connect_userinfo_id_path?: string - oidc_connect_userinfo_username_path?: string - enable_model_fallback?: boolean - fallback_model_anthropic?: string - fallback_model_openai?: string - fallback_model_gemini?: string - fallback_model_antigravity?: string - enable_identity_patch?: boolean - identity_patch_prompt?: string - ops_monitoring_enabled?: boolean - ops_realtime_monitoring_enabled?: boolean - ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string - ops_metrics_interval_seconds?: number - min_claude_code_version?: string - max_claude_code_version?: string - allow_ungrouped_key_scheduling?: boolean - enable_fingerprint_unification?: boolean - enable_metadata_passthrough?: boolean - enable_cch_signing?: boolean + registration_enabled?: boolean; + email_verify_enabled?: boolean; + registration_email_suffix_whitelist?: string[]; + promo_code_enabled?: boolean; + password_reset_enabled?: boolean; + frontend_url?: string; + invitation_code_enabled?: boolean; + totp_enabled?: boolean; // TOTP 双因素认证 + default_balance?: number; + default_concurrency?: number; + default_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_balance?: number; + auth_source_default_email_concurrency?: number; + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_grant_on_signup?: boolean; + auth_source_default_email_grant_on_first_bind?: boolean; + auth_source_default_linuxdo_balance?: number; + auth_source_default_linuxdo_concurrency?: number; + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_linuxdo_grant_on_signup?: boolean; + auth_source_default_linuxdo_grant_on_first_bind?: boolean; + auth_source_default_oidc_balance?: number; + auth_source_default_oidc_concurrency?: number; + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_oidc_grant_on_signup?: boolean; + auth_source_default_oidc_grant_on_first_bind?: boolean; + auth_source_default_wechat_balance?: number; + auth_source_default_wechat_concurrency?: number; + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_wechat_grant_on_signup?: boolean; + auth_source_default_wechat_grant_on_first_bind?: boolean; + force_email_on_third_party_signup?: boolean; + site_name?: string; + site_logo?: string; + site_subtitle?: string; + api_base_url?: string; + contact_info?: string; + doc_url?: string; + home_content?: string; + hide_ccs_import_button?: boolean; + table_default_page_size?: number; + table_page_size_options?: number[]; + backend_mode_enabled?: boolean; + custom_menu_items?: CustomMenuItem[]; + custom_endpoints?: CustomEndpoint[]; + smtp_host?: string; + smtp_port?: number; + smtp_username?: string; + smtp_password?: string; + smtp_from_email?: string; + smtp_from_name?: string; + smtp_use_tls?: boolean; + turnstile_enabled?: boolean; + turnstile_site_key?: string; + turnstile_secret_key?: string; + linuxdo_connect_enabled?: boolean; + linuxdo_connect_client_id?: string; + linuxdo_connect_client_secret?: string; + linuxdo_connect_redirect_url?: string; + wechat_connect_enabled?: boolean; + wechat_connect_app_id?: string; + wechat_connect_app_secret?: string; + wechat_connect_mode?: string; + wechat_connect_scopes?: string; + wechat_connect_redirect_url?: string; + wechat_connect_frontend_redirect_url?: string; + oidc_connect_enabled?: boolean; + oidc_connect_provider_name?: string; + oidc_connect_client_id?: string; + oidc_connect_client_secret?: string; + oidc_connect_issuer_url?: string; + oidc_connect_discovery_url?: string; + oidc_connect_authorize_url?: string; + oidc_connect_token_url?: string; + oidc_connect_userinfo_url?: string; + oidc_connect_jwks_url?: string; + oidc_connect_scopes?: string; + oidc_connect_redirect_url?: string; + oidc_connect_frontend_redirect_url?: string; + oidc_connect_token_auth_method?: string; + oidc_connect_use_pkce?: boolean; + oidc_connect_validate_id_token?: boolean; + oidc_connect_allowed_signing_algs?: string; + oidc_connect_clock_skew_seconds?: number; + oidc_connect_require_email_verified?: boolean; + oidc_connect_userinfo_email_path?: string; + oidc_connect_userinfo_id_path?: string; + oidc_connect_userinfo_username_path?: string; + enable_model_fallback?: boolean; + fallback_model_anthropic?: string; + fallback_model_openai?: string; + fallback_model_gemini?: string; + fallback_model_antigravity?: string; + enable_identity_patch?: boolean; + identity_patch_prompt?: string; + ops_monitoring_enabled?: boolean; + ops_realtime_monitoring_enabled?: boolean; + ops_query_mode_default?: "auto" | "raw" | "preagg" | string; + ops_metrics_interval_seconds?: number; + min_claude_code_version?: string; + max_claude_code_version?: string; + allow_ungrouped_key_scheduling?: boolean; + enable_fingerprint_unification?: boolean; + enable_metadata_passthrough?: boolean; + enable_cch_signing?: boolean; // Payment configuration - payment_enabled?: boolean - payment_min_amount?: number - payment_max_amount?: number - payment_daily_limit?: number - payment_order_timeout_minutes?: number - payment_max_pending_orders?: number - payment_enabled_types?: string[] - payment_balance_disabled?: boolean - payment_balance_recharge_multiplier?: number - payment_recharge_fee_rate?: number - payment_load_balance_strategy?: string - payment_product_name_prefix?: string - payment_product_name_suffix?: string - payment_help_image_url?: string - payment_help_text?: string - payment_cancel_rate_limit_enabled?: boolean - payment_cancel_rate_limit_max?: number - payment_cancel_rate_limit_window?: number - payment_cancel_rate_limit_unit?: string - payment_cancel_rate_limit_window_mode?: string - payment_visible_method_alipay_source?: string - payment_visible_method_wxpay_source?: string - payment_visible_method_alipay_enabled?: boolean - payment_visible_method_wxpay_enabled?: boolean - openai_advanced_scheduler_enabled?: boolean + payment_enabled?: boolean; + payment_min_amount?: number; + payment_max_amount?: number; + payment_daily_limit?: number; + payment_order_timeout_minutes?: number; + payment_max_pending_orders?: number; + payment_enabled_types?: string[]; + payment_balance_disabled?: boolean; + payment_balance_recharge_multiplier?: number; + payment_recharge_fee_rate?: number; + payment_load_balance_strategy?: string; + payment_product_name_prefix?: string; + payment_product_name_suffix?: string; + payment_help_image_url?: string; + payment_help_text?: string; + payment_cancel_rate_limit_enabled?: boolean; + payment_cancel_rate_limit_max?: number; + payment_cancel_rate_limit_window?: number; + payment_cancel_rate_limit_unit?: string; + payment_cancel_rate_limit_window_mode?: string; + payment_visible_method_alipay_source?: string; + payment_visible_method_wxpay_source?: string; + payment_visible_method_alipay_enabled?: boolean; + payment_visible_method_wxpay_enabled?: boolean; + openai_advanced_scheduler_enabled?: boolean; // Balance & quota notification - balance_low_notify_enabled?: boolean - balance_low_notify_threshold?: number - balance_low_notify_recharge_url?: string - account_quota_notify_enabled?: boolean - account_quota_notify_emails?: NotifyEmailEntry[] + balance_low_notify_enabled?: boolean; + balance_low_notify_threshold?: number; + balance_low_notify_recharge_url?: string; + account_quota_notify_enabled?: boolean; + account_quota_notify_emails?: NotifyEmailEntry[]; } /** @@ -453,8 +554,8 @@ export interface UpdateSettingsRequest { * @returns System settings */ export async function getSettings(): Promise { - const { data } = await apiClient.get('/admin/settings') - return data + const { data } = await apiClient.get("/admin/settings"); + return data; } /** @@ -462,20 +563,25 @@ export async function getSettings(): Promise { * @param settings - Partial settings to update * @returns Updated settings */ -export async function updateSettings(settings: UpdateSettingsRequest): Promise { - const { data } = await apiClient.put('/admin/settings', settings) - return data +export async function updateSettings( + settings: UpdateSettingsRequest, +): Promise { + const { data } = await apiClient.put( + "/admin/settings", + settings, + ); + return data; } /** * Test SMTP connection request */ export interface TestSmtpRequest { - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password: string - smtp_use_tls: boolean + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password: string; + smtp_use_tls: boolean; } /** @@ -483,23 +589,28 @@ export interface TestSmtpRequest { * @param config - SMTP configuration to test * @returns Test result message */ -export async function testSmtpConnection(config: TestSmtpRequest): Promise<{ message: string }> { - const { data } = await apiClient.post<{ message: string }>('/admin/settings/test-smtp', config) - return data +export async function testSmtpConnection( + config: TestSmtpRequest, +): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>( + "/admin/settings/test-smtp", + config, + ); + return data; } /** * Send test email request */ export interface SendTestEmailRequest { - email: string - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password: string - smtp_from_email: string - smtp_from_name: string - smtp_use_tls: boolean + email: string; + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password: string; + smtp_from_email: string; + smtp_from_name: string; + smtp_use_tls: boolean; } /** @@ -507,20 +618,22 @@ export interface SendTestEmailRequest { * @param request - Email address and SMTP config * @returns Test result message */ -export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ message: string }> { +export async function sendTestEmail( + request: SendTestEmailRequest, +): Promise<{ message: string }> { const { data } = await apiClient.post<{ message: string }>( - '/admin/settings/send-test-email', - request - ) - return data + "/admin/settings/send-test-email", + request, + ); + return data; } /** * Admin API Key status response */ export interface AdminApiKeyStatus { - exists: boolean - masked_key: string + exists: boolean; + masked_key: string; } /** @@ -528,8 +641,10 @@ export interface AdminApiKeyStatus { * @returns Status indicating if key exists and masked version */ export async function getAdminApiKey(): Promise { - const { data } = await apiClient.get('/admin/settings/admin-api-key') - return data + const { data } = await apiClient.get( + "/admin/settings/admin-api-key", + ); + return data; } /** @@ -537,8 +652,10 @@ export async function getAdminApiKey(): Promise { * @returns The new full API key (only shown once) */ export async function regenerateAdminApiKey(): Promise<{ key: string }> { - const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate') - return data + const { data } = await apiClient.post<{ key: string }>( + "/admin/settings/admin-api-key/regenerate", + ); + return data; } /** @@ -546,8 +663,10 @@ export async function regenerateAdminApiKey(): Promise<{ key: string }> { * @returns Success message */ export async function deleteAdminApiKey(): Promise<{ message: string }> { - const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key') - return data + const { data } = await apiClient.delete<{ message: string }>( + "/admin/settings/admin-api-key", + ); + return data; } // ==================== Overload Cooldown Settings ==================== @@ -556,23 +675,25 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> { * Overload cooldown settings interface (529 handling) */ export interface OverloadCooldownSettings { - enabled: boolean - cooldown_minutes: number + enabled: boolean; + cooldown_minutes: number; } export async function getOverloadCooldownSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/overload-cooldown') - return data + const { data } = await apiClient.get( + "/admin/settings/overload-cooldown", + ); + return data; } export async function updateOverloadCooldownSettings( - settings: OverloadCooldownSettings + settings: OverloadCooldownSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/overload-cooldown', - settings - ) - return data + "/admin/settings/overload-cooldown", + settings, + ); + return data; } // ==================== Stream Timeout Settings ==================== @@ -581,11 +702,11 @@ export async function updateOverloadCooldownSettings( * Stream timeout settings interface */ export interface StreamTimeoutSettings { - enabled: boolean - action: 'temp_unsched' | 'error' | 'none' - temp_unsched_minutes: number - threshold_count: number - threshold_window_minutes: number + enabled: boolean; + action: "temp_unsched" | "error" | "none"; + temp_unsched_minutes: number; + threshold_count: number; + threshold_window_minutes: number; } /** @@ -593,8 +714,10 @@ export interface StreamTimeoutSettings { * @returns Stream timeout settings */ export async function getStreamTimeoutSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/stream-timeout') - return data + const { data } = await apiClient.get( + "/admin/settings/stream-timeout", + ); + return data; } /** @@ -603,13 +726,13 @@ export async function getStreamTimeoutSettings(): Promise * @returns Updated settings */ export async function updateStreamTimeoutSettings( - settings: StreamTimeoutSettings + settings: StreamTimeoutSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/stream-timeout', - settings - ) - return data + "/admin/settings/stream-timeout", + settings, + ); + return data; } // ==================== Rectifier Settings ==================== @@ -618,11 +741,11 @@ export async function updateStreamTimeoutSettings( * Rectifier settings interface */ export interface RectifierSettings { - enabled: boolean - thinking_signature_enabled: boolean - thinking_budget_enabled: boolean - apikey_signature_enabled: boolean - apikey_signature_patterns: string[] + enabled: boolean; + thinking_signature_enabled: boolean; + thinking_budget_enabled: boolean; + apikey_signature_enabled: boolean; + apikey_signature_patterns: string[]; } /** @@ -630,8 +753,10 @@ export interface RectifierSettings { * @returns Rectifier settings */ export async function getRectifierSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/rectifier') - return data + const { data } = await apiClient.get( + "/admin/settings/rectifier", + ); + return data; } /** @@ -640,13 +765,13 @@ export async function getRectifierSettings(): Promise { * @returns Updated settings */ export async function updateRectifierSettings( - settings: RectifierSettings + settings: RectifierSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/rectifier', - settings - ) - return data + "/admin/settings/rectifier", + settings, + ); + return data; } // ==================== Beta Policy Settings ==================== @@ -655,20 +780,20 @@ export async function updateRectifierSettings( * Beta policy rule interface */ export interface BetaPolicyRule { - beta_token: string - action: 'pass' | 'filter' | 'block' - scope: 'all' | 'oauth' | 'apikey' | 'bedrock' - error_message?: string - model_whitelist?: string[] - fallback_action?: 'pass' | 'filter' | 'block' - fallback_error_message?: string + beta_token: string; + action: "pass" | "filter" | "block"; + scope: "all" | "oauth" | "apikey" | "bedrock"; + error_message?: string; + model_whitelist?: string[]; + fallback_action?: "pass" | "filter" | "block"; + fallback_error_message?: string; } /** * Beta policy settings interface */ export interface BetaPolicySettings { - rules: BetaPolicyRule[] + rules: BetaPolicyRule[]; } /** @@ -676,8 +801,10 @@ export interface BetaPolicySettings { * @returns Beta policy settings */ export async function getBetaPolicySettings(): Promise { - const { data } = await apiClient.get('/admin/settings/beta-policy') - return data + const { data } = await apiClient.get( + "/admin/settings/beta-policy", + ); + return data; } /** @@ -686,70 +813,73 @@ export async function getBetaPolicySettings(): Promise { * @returns Updated settings */ export async function updateBetaPolicySettings( - settings: BetaPolicySettings + settings: BetaPolicySettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/beta-policy', - settings - ) - return data + "/admin/settings/beta-policy", + settings, + ); + return data; } // --- Web Search Emulation Config --- export interface WebSearchProviderConfig { - type: 'brave' | 'tavily' - api_key: string - api_key_configured: boolean - quota_limit: number | null - subscribed_at: number | null - quota_used?: number - proxy_id: number | null - expires_at: number | null + type: "brave" | "tavily"; + api_key: string; + api_key_configured: boolean; + quota_limit: number | null; + subscribed_at: number | null; + quota_used?: number; + proxy_id: number | null; + expires_at: number | null; } export interface WebSearchEmulationConfig { - enabled: boolean - providers: WebSearchProviderConfig[] + enabled: boolean; + providers: WebSearchProviderConfig[]; } export interface WebSearchTestResult { - provider: string - results: { url: string; title: string; snippet: string; page_age?: string }[] - query: string + provider: string; + results: { url: string; title: string; snippet: string; page_age?: string }[]; + query: string; } export async function getWebSearchEmulationConfig(): Promise { const { data } = await apiClient.get( - '/admin/settings/web-search-emulation' - ) - return data + "/admin/settings/web-search-emulation", + ); + return data; } export async function updateWebSearchEmulationConfig( - config: WebSearchEmulationConfig + config: WebSearchEmulationConfig, ): Promise { const { data } = await apiClient.put( - '/admin/settings/web-search-emulation', - config - ) - return data + "/admin/settings/web-search-emulation", + config, + ); + return data; } export async function testWebSearchEmulation( - query: string + query: string, ): Promise { const { data } = await apiClient.post( - '/admin/settings/web-search-emulation/test', - { query } - ) - return data + "/admin/settings/web-search-emulation/test", + { query }, + ); + return data; } -export async function resetWebSearchUsage( - payload: { provider_type: string } -): Promise { - await apiClient.post('/admin/settings/web-search-emulation/reset-usage', payload) +export async function resetWebSearchUsage(payload: { + provider_type: string; +}): Promise { + await apiClient.post( + "/admin/settings/web-search-emulation/reset-usage", + payload, + ); } export const settingsAPI = { @@ -771,7 +901,7 @@ export const settingsAPI = { getWebSearchEmulationConfig, updateWebSearchEmulationConfig, testWebSearchEmulation, - resetWebSearchUsage -} + resetWebSearchUsage, +}; -export default settingsAPI +export default settingsAPI; diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index e56afe5f..9612d9f9 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3,7 +3,9 @@
-
+
@@ -15,7 +17,10 @@ v-for="tab in settingsTabs" :key="tab.key" type="button" - :class="['settings-tab', activeTab === tab.key && 'settings-tab-active']" + :class="[ + 'settings-tab', + activeTab === tab.key && 'settings-tab-active', + ]" @click="activeTab = tab.key" > @@ -28,478 +33,4084 @@
- -
-
-

- {{ t('admin.settings.adminApiKey.title') }} -

-

- {{ t('admin.settings.adminApiKey.description') }} -

-
-
- + +
-
- -

- {{ t('admin.settings.adminApiKey.securityWarning') }} -

-
+

+ {{ t("admin.settings.adminApiKey.title") }} +

+

+ {{ t("admin.settings.adminApiKey.description") }} +

- - -
-
- {{ t('common.loading') }} -
- - -
- - {{ t('admin.settings.adminApiKey.notConfigured') }} - - -
- - -
-
-
- - - {{ adminApiKeyMasked }} - -
-
- - -
-
- - +
+
-

- {{ t('admin.settings.adminApiKey.keyWarning') }} -

-
- - {{ newAdminApiKey }} - - +
+ +

+ {{ t("admin.settings.adminApiKey.securityWarning") }} +

+
+
+ + +
+
+ {{ t("common.loading") }} +
+ + +
+ + {{ t("admin.settings.adminApiKey.notConfigured") }} + + +
+ + +
+
+
+ + + {{ adminApiKeyMasked }} + +
+
+ + +
+
+ + +
+

+ {{ t("admin.settings.adminApiKey.keyWarning") }} +

+
+ + {{ newAdminApiKey }} + + +
+

+ {{ t("admin.settings.adminApiKey.usage") }} +

-

- {{ t('admin.settings.adminApiKey.usage') }} -

-
+
- - -
-
-

- {{ t('admin.settings.overloadCooldown.title') }} -

-

- {{ t('admin.settings.overloadCooldown.description') }} -

-
-
-
-
- {{ t('common.loading') }} + +
+
+

+ {{ t("admin.settings.overloadCooldown.title") }} +

+

+ {{ t("admin.settings.overloadCooldown.description") }} +

- - -
-
- - -
-
-

- {{ t('admin.settings.streamTimeout.title') }} -

-

- {{ t('admin.settings.streamTimeout.description') }} -

-
-
- -
-
- {{ t('common.loading') }} -
- - -
-
- - -
-
-

- {{ t('admin.settings.rectifier.title') }} -

-

- {{ t('admin.settings.rectifier.description') }} -

-
-
- -
-
- {{ t('common.loading') }} -
- - +
-
- - - + +
- -
-
-
- -
-

- {{ t('admin.settings.emailTabDisabledTitle') }} -

-

- {{ t('admin.settings.emailTabDisabledHint') }} -

+ +
+
+
+ +
+

+ {{ t("admin.settings.emailTabDisabledTitle") }} +

+

+ {{ t("admin.settings.emailTabDisabledHint") }} +

+
-
- -
-
-
-

- {{ t('admin.settings.smtp.title') }} -

-

- {{ t('admin.settings.smtp.description') }} -

-
- -
-
-
-
- - -
-
- - -
-
- - -
-
- - -

- {{ - form.smtp_password_configured - ? t('admin.settings.smtp.passwordConfiguredHint') - : t('admin.settings.smtp.passwordHint') - }} -

-
-
- - -
-
- - -
-
- - + +
- -

- {{ t('admin.settings.smtp.useTlsHint') }} +

+ {{ t("admin.settings.smtp.title") }} +

+

+ {{ t("admin.settings.smtp.description") }}

- -
- -
-
- - -
-
-

- {{ t('admin.settings.testEmail.title') }} -

-

- {{ t('admin.settings.testEmail.description') }} -

-
-
-
-
- - -
-
-
- -
-
-

- {{ t('admin.settings.balanceNotify.title') }} -

-

- {{ t('admin.settings.balanceNotify.description') }} -

-
-
-
- - -
-
- -
- $ - -
-

{{ t('admin.settings.balanceNotify.thresholdHint') }}

-
-
- - -

{{ t('admin.settings.balanceNotify.rechargeUrlHint') }}

-
-
-
- - -
-
-

- {{ t('admin.settings.quotaNotify.title') }} -

-

- {{ t('admin.settings.quotaNotify.description') }} -

-
-
-
- - -
-
- -
-
-
+ + +
+
+

+ {{ t("admin.settings.testEmail.title") }} +

+

+ {{ t("admin.settings.testEmail.description") }} +

+
+
+
+
+ + +
+
-

{{ t('admin.settings.quotaNotify.emailsHint') }}

+
+
+ +
+
+

+ {{ t("admin.settings.balanceNotify.title") }} +

+

+ {{ t("admin.settings.balanceNotify.description") }} +

+
+
+
+ + +
+
+ +
+ $ + +
+

+ {{ t("admin.settings.balanceNotify.thresholdHint") }} +

+
+
+ + +

+ {{ t("admin.settings.balanceNotify.rechargeUrlHint") }} +

+
+
+
+ + +
+
+

+ {{ t("admin.settings.quotaNotify.title") }} +

+

+ {{ t("admin.settings.quotaNotify.description") }} +

+
+
+
+ + +
+
+ +
+
+ + + +
+ +
+

+ {{ t("admin.settings.quotaNotify.emailsHint") }} +

+
-
+
@@ -3070,8 +4644,17 @@
-
@@ -3104,22 +4691,32 @@ @close="showProviderDialog = false" @save="handleSaveProvider" /> - +
diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue index 67409a7c..2e3db61b 100644 --- a/frontend/src/components/admin/account/AccountTestModal.vue +++ b/frontend/src/components/admin/account/AccountTestModal.vue @@ -55,12 +55,12 @@ />
-
+