feat(rpm): RPM 限流模块优化
P0:
- rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7)
- 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数)

P1:
- ClearAll 按钮直连 DELETE API,带 loading 防重复
- 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点

优化:
- checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效
- Override/Group 变更后自动失效 auth cache
- fail-open 语义不变,Redis 故障不阻塞业务
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -32,6 +33,7 @@ type AdminService interface {
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
|
||||
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
|
||||
// codeType is optional - pass empty string to return all types.
|
||||
// Also returns totalRecharged (sum of all positive balance top-ups).
|
||||
@@ -50,6 +52,8 @@ type AdminService interface {
|
||||
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
|
||||
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||
ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
|
||||
BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
|
||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||
|
||||
// API Key management (admin)
|
||||
@@ -114,6 +118,7 @@ type CreateUserInput struct {
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
RPMLimit int
|
||||
AllowedGroups []int64
|
||||
}
|
||||
|
||||
@@ -124,6 +129,7 @@ type UpdateUserInput struct {
|
||||
Notes *string
|
||||
Balance *float64 // 使用指针区分"未提供"和"设置为0"
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
RPMLimit *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
@@ -199,6 +205,8 @@ type CreateGroupInput struct {
|
||||
RequireOAuthOnly bool
|
||||
RequirePrivacySet bool
|
||||
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
|
||||
// RPMLimit 分组 RPM 上限(0 = 不限制)
|
||||
RPMLimit int
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -234,6 +242,8 @@ type UpdateGroupInput struct {
|
||||
RequireOAuthOnly *bool
|
||||
RequirePrivacySet *bool
|
||||
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
|
||||
// RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
|
||||
RPMLimit *int
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct {
|
||||
MigratedKeys int64 // 迁移的 Key 数量
|
||||
}
|
||||
|
||||
type (
	// UserRPMStatus describes a user's current per-minute RPM usage: the
	// user-level counter and limit, plus one entry per group.
	UserRPMStatus struct {
		// UserRPMUsed is the user-level request count for the current minute.
		UserRPMUsed int `json:"user_rpm_used"`
		// UserRPMLimit is the user-level RPM limit.
		UserRPMLimit int `json:"user_rpm_limit"`
		// PerGroup holds the usage and effective limit for each user/group pair.
		PerGroup []UserGroupRPMStatus `json:"per_group"`
	}

	// UserGroupRPMStatus describes current per-minute RPM usage for one
	// user/group pair.
	UserGroupRPMStatus struct {
		GroupID   int64  `json:"group_id"`
		GroupName string `json:"group_name"`
		// Used is the request count for this user/group pair in the current minute.
		Used int `json:"used"`
		// Limit is the RPM limit reported for this pair.
		Limit int `json:"limit"`
		// Source names where Limit came from: "group" | "override".
		Source string `json:"source"`
	}
)
|
||||
|
||||
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||
type BulkUpdateAccountsResult struct {
|
||||
Success int `json:"success"`
|
||||
@@ -463,6 +489,8 @@ const (
|
||||
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
// ErrRPMStatusUnavailable is returned by GetUserRPMStatus when the service has
// no RPM cache configured; it is built with HTTP status 501 (Not Implemented).
var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available")
|
||||
|
||||
// adminServiceImpl implements AdminService
|
||||
type adminServiceImpl struct {
|
||||
userRepo UserRepository
|
||||
@@ -472,6 +500,7 @@ type adminServiceImpl struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
userRPMCache UserRPMCache
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
proxyLatencyCache ProxyLatencyCache
|
||||
@@ -496,6 +525,7 @@ func NewAdminService(
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
userRPMCache UserRPMCache,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
proxyLatencyCache ProxyLatencyCache,
|
||||
@@ -514,6 +544,7 @@ func NewAdminService(
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
userRPMCache: userRPMCache,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
proxyLatencyCache: proxyLatencyCache,
|
||||
@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
||||
Role: RoleUser, // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
RPMLimit: input.RPMLimit,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
}
|
||||
@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
oldConcurrency := user.Concurrency
|
||||
oldStatus := user.Status
|
||||
oldRole := user.Role
|
||||
oldRPMLimit := user.RPMLimit
|
||||
|
||||
if input.Email != "" {
|
||||
user.Email = input.Email
|
||||
@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
user.Concurrency = *input.Concurrency
|
||||
}
|
||||
|
||||
if input.RPMLimit != nil {
|
||||
user.RPMLimit = *input.RPMLimit
|
||||
}
|
||||
|
||||
if input.AllowedGroups != nil {
|
||||
user.AllowedGroups = *input.AllowedGroups
|
||||
}
|
||||
@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
||||
// RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
|
||||
// 不失效缓存会让修改在一个 L2 TTL 内失去效果。
|
||||
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||
}
|
||||
}
|
||||
@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) {
|
||||
if s.userRPMCache == nil {
|
||||
return nil, ErrRPMStatusUnavailable
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
|
||||
keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDSet := make(map[int64]struct{})
|
||||
for _, key := range keys {
|
||||
if key.GroupID != nil && *key.GroupID > 0 {
|
||||
groupIDSet[*key.GroupID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groupIDSet))
|
||||
for groupID := range groupIDSet {
|
||||
groupIDs = append(groupIDs, groupID)
|
||||
}
|
||||
sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
|
||||
|
||||
var perGroup []UserGroupRPMStatus
|
||||
for _, groupID := range groupIDs {
|
||||
used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID)
|
||||
if getErr != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr)
|
||||
}
|
||||
|
||||
entry := UserGroupRPMStatus{
|
||||
GroupID: groupID,
|
||||
Used: used,
|
||||
}
|
||||
|
||||
if s.groupRepo != nil {
|
||||
if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil {
|
||||
entry.GroupName = group.Name
|
||||
entry.Limit = group.RPMLimit
|
||||
entry.Source = "group"
|
||||
} else if groupErr != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr)
|
||||
}
|
||||
}
|
||||
|
||||
if s.userGroupRateRepo != nil {
|
||||
override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID)
|
||||
if overrideErr != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr)
|
||||
} else if override != nil {
|
||||
entry.Limit = *override
|
||||
entry.Source = "override"
|
||||
}
|
||||
}
|
||||
|
||||
perGroup = append(perGroup, entry)
|
||||
}
|
||||
|
||||
return &UserRPMStatus{
|
||||
UserRPMUsed: userRPMUsed,
|
||||
UserRPMLimit: user.RPMLimit,
|
||||
PerGroup: perGroup,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||
// Return mock data for now
|
||||
return map[string]any{
|
||||
@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
RequirePrivacySet: input.RequirePrivacySet,
|
||||
DefaultMappedModel: input.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
|
||||
RPMLimit: input.RPMLimit,
|
||||
}
|
||||
sanitizeGroupMessagesDispatchFields(group)
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.MessagesDispatchModelConfig != nil {
|
||||
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
|
||||
}
|
||||
if input.RPMLimit != nil {
|
||||
group.RPMLimit = *input.RPMLimit
|
||||
}
|
||||
sanitizeGroupMessagesDispatchFields(group)
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
|
||||
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||
// 去重源分组 IDs
|
||||
@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
|
||||
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.RPMOverride != nil && *e.RPMOverride < 0 {
|
||||
return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID))
|
||||
}
|
||||
}
|
||||
if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil {
|
||||
return err
|
||||
}
|
||||
// RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
|
||||
syncedGroupID int64
|
||||
syncedEntries []GroupRateMultiplierInput
|
||||
syncGroupErr error
|
||||
|
||||
rpmSyncedGroupID int64
|
||||
rpmSyncedEntries []GroupRPMOverrideInput
|
||||
rpmSyncErr error
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
|
||||
@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
|
||||
panic("unexpected GetByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||||
panic("unexpected GetRPMOverrideByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||
if s.getByGroupIDErr != nil {
|
||||
return nil, s.getByGroupIDErr
|
||||
@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
|
||||
return s.syncGroupErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
|
||||
s.rpmSyncedGroupID = groupID
|
||||
s.rpmSyncedEntries = entries
|
||||
return s.rpmSyncErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
|
||||
panic("unexpected ClearGroupRPMOverrides call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
|
||||
return s.deleteByGroupErr
|
||||
@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDData: map[int64][]UserGroupRateEntry{
|
||||
10: {
|
||||
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
|
||||
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
|
||||
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)},
|
||||
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
||||
require.Len(t, entries, 2)
|
||||
require.Equal(t, int64(1), entries[0].UserID)
|
||||
require.Equal(t, "alice", entries[0].UserName)
|
||||
require.Equal(t, 1.5, entries[0].RateMultiplier)
|
||||
require.NotNil(t, entries[0].RateMultiplier)
|
||||
require.Equal(t, 1.5, *entries[0].RateMultiplier)
|
||||
require.Equal(t, int64(2), entries[1].UserID)
|
||||
require.Equal(t, 0.8, entries[1].RateMultiplier)
|
||||
require.NotNil(t, entries[1].RateMultiplier)
|
||||
require.Equal(t, 0.8, *entries[1].RateMultiplier)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "sync failed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) {
|
||||
t.Run("syncs entries to repo", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
override := 20
|
||||
entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}}
|
||||
|
||||
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), repo.rpmSyncedGroupID)
|
||||
require.Equal(t, entries, repo.rpmSyncedEntries)
|
||||
})
|
||||
|
||||
t.Run("rejects negative override as bad request", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
negative := -1
|
||||
|
||||
err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{
|
||||
{UserID: 2, RPMOverride: &negative},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusBadRequest, infraerrors.Code(err))
|
||||
require.Zero(t, repo.rpmSyncedGroupID)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
RPMLimit: 10,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
groupRepo: repo,
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
rpmLimit := 60
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
RPMLimit: &rpmLimit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.Equal(t, 60, repo.updated.RPMLimit)
|
||||
require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存")
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
|
||||
panic("unexpected GetByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||||
panic("unexpected GetRPMOverrideByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
|
||||
panic("unexpected SyncGroupRateMultipliers call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error {
|
||||
panic("unexpected SyncGroupRPMOverrides call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
|
||||
panic("unexpected ClearGroupRPMOverrides call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
|
||||
panic("unexpected DeleteByGroupID call")
|
||||
}
|
||||
|
||||
112
backend/internal/service/admin_service_rpm_status_test.go
Normal file
112
backend/internal/service/admin_service_rpm_status_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type rpmStatusUserRepoStub struct {
|
||||
UserRepository
|
||||
user *User
|
||||
}
|
||||
|
||||
func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
|
||||
return s.user, nil
|
||||
}
|
||||
|
||||
type rpmStatusAPIKeyRepoStub struct {
|
||||
APIKeyRepository
|
||||
keys []APIKey
|
||||
}
|
||||
|
||||
func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil
|
||||
}
|
||||
|
||||
type rpmStatusGroupRepoStub struct {
|
||||
GroupRepository
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
return s.groups[id], nil
|
||||
}
|
||||
|
||||
type rpmStatusRateRepoStub struct {
|
||||
UserGroupRateRepository
|
||||
overrides map[int64]*int
|
||||
}
|
||||
|
||||
func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) {
|
||||
return s.overrides[groupID], nil
|
||||
}
|
||||
|
||||
type rpmStatusCacheStub struct {
|
||||
UserRPMCache
|
||||
userUsed int
|
||||
groupUsed map[int64]int
|
||||
}
|
||||
|
||||
func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) {
|
||||
return s.groupUsed[groupID], nil
|
||||
}
|
||||
|
||||
func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) {
|
||||
return s.userUsed, nil
|
||||
}
|
||||
|
||||
func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) {
|
||||
groupOneID := int64(1)
|
||||
groupTwoID := int64(2)
|
||||
override := 7
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: &rpmStatusUserRepoStub{user: &User{
|
||||
ID: 42,
|
||||
RPMLimit: 20,
|
||||
}},
|
||||
apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{
|
||||
{ID: 100, UserID: 42, GroupID: &groupTwoID},
|
||||
{ID: 101, UserID: 42, GroupID: &groupOneID},
|
||||
{ID: 102, UserID: 42, GroupID: &groupTwoID},
|
||||
{ID: 103, UserID: 42},
|
||||
}},
|
||||
groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{
|
||||
groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10},
|
||||
groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60},
|
||||
}},
|
||||
userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{
|
||||
groupTwoID: &override,
|
||||
}},
|
||||
userRPMCache: &rpmStatusCacheStub{
|
||||
userUsed: 5,
|
||||
groupUsed: map[int64]int{
|
||||
groupOneID: 3,
|
||||
groupTwoID: 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
status, err := svc.GetUserRPMStatus(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &UserRPMStatus{
|
||||
UserRPMUsed: 5,
|
||||
UserRPMLimit: 20,
|
||||
PerGroup: []UserGroupRPMStatus{
|
||||
{GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"},
|
||||
{GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"},
|
||||
},
|
||||
}, status)
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
|
||||
// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
|
||||
type rpmUserRepoStub struct {
|
||||
*userRepoStub
|
||||
lastUpdated *User
|
||||
}
|
||||
|
||||
func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *user
|
||||
s.lastUpdated = &clone
|
||||
if s.userRepoStub != nil {
|
||||
s.userRepoStub.user = &clone
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
|
||||
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}}
|
||||
repo := &rpmUserRepoStub{userRepoStub: base}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: &redeemRepoStub{},
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
newRPM := 60
|
||||
updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
|
||||
RPMLimit: &newRPM,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 60, updated.RPMLimit)
|
||||
require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存")
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) {
|
||||
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}}
|
||||
repo := &rpmUserRepoStub{userRepoStub: base}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: &redeemRepoStub{},
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
newName := "new"
|
||||
sameRPM := 10
|
||||
_, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
|
||||
Username: &newName,
|
||||
RPMLimit: &sameRPM,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效")
|
||||
}
|
||||
@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
|
||||
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
|
||||
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
|
||||
TotalRecharged float64 `json:"total_recharged"`
|
||||
|
||||
// RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
|
||||
// UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
|
||||
// nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
|
||||
UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"`
|
||||
}
|
||||
|
||||
// APIKeyAuthGroupSnapshot 分组快照
|
||||
@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
|
||||
|
||||
// RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/dgraph-io/ristretto"
|
||||
)
|
||||
|
||||
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
|
||||
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
|
||||
|
||||
type apiKeyAuthCacheConfig struct {
|
||||
l1Size int
|
||||
@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
apiKey.Key = key
|
||||
snapshot := s.snapshotFromAPIKey(apiKey)
|
||||
snapshot := s.snapshotFromAPIKey(ctx, apiKey)
|
||||
if snapshot == nil {
|
||||
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
||||
}
|
||||
@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
|
||||
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
if apiKey == nil || apiKey.User == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
|
||||
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
|
||||
TotalRecharged: apiKey.User.TotalRecharged,
|
||||
RPMLimit: apiKey.User.RPMLimit,
|
||||
},
|
||||
}
|
||||
|
||||
// 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
|
||||
if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
|
||||
override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
|
||||
if err == nil && override != nil {
|
||||
snapshot.User.UserGroupRPMOverride = override
|
||||
}
|
||||
// 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||
ID: apiKey.Group.ID,
|
||||
@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
|
||||
RPMLimit: apiKey.Group.RPMLimit,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
|
||||
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
|
||||
TotalRecharged: snapshot.User.TotalRecharged,
|
||||
RPMLimit: snapshot.User.RPMLimit,
|
||||
UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
|
||||
},
|
||||
}
|
||||
if snapshot.Group != nil {
|
||||
@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
|
||||
RPMLimit: snapshot.Group.RPMLimit,
|
||||
}
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
|
||||
},
|
||||
}
|
||||
|
||||
snapshot := svc.snapshotFromAPIKey(apiKey)
|
||||
snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
|
||||
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
|
||||
|
||||
require.NotNil(t, roundTrip)
|
||||
|
||||
@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
|
||||
|
||||
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &User{
|
||||
Email: email,
|
||||
@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
|
||||
signupSource := inferLegacySignupSource(email)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
|
||||
signupSource := inferLegacySignupSource(email)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
RPMLimit: defaultRPMLimit,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
|
||||
@@ -20,6 +20,9 @@ import (
|
||||
var (
	ErrSubscriptionInvalid       = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
	ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
	// RPM limit exceeded errors, returned by checkRPM.
	// NOTE(review): the original note says gateway_handler maps these to HTTP 429;
	// that mapping is not visible here — confirm.
	ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
	ErrUserRPMExceeded  = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
@@ -87,6 +90,8 @@ type BillingCacheService struct {
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
apiKeyRateLimitLoader apiKeyRateLimitLoader
|
||||
userRPMCache UserRPMCache
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
|
||||
@@ -104,12 +109,22 @@ type BillingCacheService struct {
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
|
||||
func NewBillingCacheService(
|
||||
cache BillingCache,
|
||||
userRepo UserRepository,
|
||||
subRepo UserSubscriptionRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
userRPMCache UserRPMCache,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cfg *config.Config,
|
||||
) *BillingCacheService {
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
apiKeyRateLimitLoader: apiKeyRepo,
|
||||
userRPMCache: userRPMCache,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
}
|
||||
}
|
||||
|
||||
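// RPM rate limiting: the group-level limit (override or group.rpm_limit) and the user-level limit are both enforced. It runs last so that requests already doomed to fail do not increment the counters.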
|
||||
if err := s.checkRPM(ctx, user, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝:
|
||||
//
|
||||
// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。
|
||||
// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。
|
||||
// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。
|
||||
// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。
|
||||
//
|
||||
// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。
|
||||
// Redis 故障一律 fail-open(打 warning,不阻塞业务)。
|
||||
func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
|
||||
if s == nil || s.userRPMCache == nil || user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ── 第一层:分组级检查(override 或 group.rpm_limit) ──
|
||||
if group != nil {
|
||||
// 解析 override:优先从 auth cache snapshot,nil 时回退 DB。
|
||||
var override *int
|
||||
if user.UserGroupRPMOverride != nil {
|
||||
override = user.UserGroupRPMOverride
|
||||
} else if s.userGroupRateRepo != nil {
|
||||
dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"service.billing_cache",
|
||||
"Warning: rpm override lookup failed for user=%d group=%d: %v",
|
||||
user.ID, group.ID, err,
|
||||
)
|
||||
} else {
|
||||
override = dbOverride
|
||||
}
|
||||
}
|
||||
|
||||
if override != nil {
|
||||
// override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。
|
||||
if *override > 0 {
|
||||
count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
|
||||
if incErr != nil {
|
||||
logger.LegacyPrintf(
|
||||
"service.billing_cache",
|
||||
"Warning: rpm increment (override) failed for user=%d group=%d: %v",
|
||||
user.ID, group.ID, incErr,
|
||||
)
|
||||
// fail-open
|
||||
} else if count > *override {
|
||||
return ErrGroupRPMExceeded
|
||||
}
|
||||
}
|
||||
// override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。
|
||||
} else if group.RPMLimit > 0 {
|
||||
// 无 override,检查 group.rpm_limit。
|
||||
count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"service.billing_cache",
|
||||
"Warning: rpm increment (group) failed for user=%d group=%d: %v",
|
||||
user.ID, group.ID, err,
|
||||
)
|
||||
// fail-open
|
||||
} else if count > group.RPMLimit {
|
||||
return ErrGroupRPMExceeded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── 第二层:用户级全局硬上限(始终生效) ──
|
||||
if user.RPMLimit > 0 {
|
||||
count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"service.billing_cache",
|
||||
"Warning: rpm increment (user) failed for user=%d: %v",
|
||||
user.ID, err,
|
||||
)
|
||||
return nil // fail-open
|
||||
}
|
||||
if count > user.RPMLimit {
|
||||
return ErrUserRPMExceeded
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
253
backend/internal/service/billing_cache_service_rpm_test.go
Normal file
253
backend/internal/service/billing_cache_service_rpm_test.go
Normal file
@@ -0,0 +1,253 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// userRPMCacheStub is a test double for UserRPMCache. It counts how often each
// increment method is called and lets a test script the returned counts or
// inject an error.
type userRPMCacheStub struct {
	userGroupCalls int32 // number of IncrementUserGroupRPM calls so far
	userCalls      int32 // number of IncrementUserRPM calls so far

	userGroupCounts []int // counts returned by successive IncrementUserGroupRPM calls
	userGroupErr    error // when non-nil, IncrementUserGroupRPM always returns it
	userCounts      []int // counts returned by successive IncrementUserRPM calls
	userErr         error // when non-nil, IncrementUserRPM always returns it
}

// scriptedRPMCount records one call on calls and produces the stub result for
// that call: the injected failure if any, else the scripted value at the
// call's position, else 1 once the script is exhausted.
func scriptedRPMCount(calls *int32, scripted []int, failure error) (int, error) {
	position := int(atomic.AddInt32(calls, 1)) - 1
	switch {
	case failure != nil:
		return 0, failure
	case position < len(scripted):
		return scripted[position], nil
	default:
		return 1, nil
	}
}

// IncrementUserGroupRPM returns the next scripted (user, group) count; the call
// is counted even when an error is injected.
func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
	return scriptedRPMCount(&s.userGroupCalls, s.userGroupCounts, s.userGroupErr)
}

// IncrementUserRPM returns the next scripted user-level count; the call is
// counted even when an error is injected.
func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
	return scriptedRPMCount(&s.userCalls, s.userCounts, s.userErr)
}

// GetUserGroupRPM is a read-only no-op that always reports zero usage.
func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
	return 0, nil
}

// GetUserRPM is a read-only no-op that always reports zero usage.
func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
	return 0, nil
}
|
||||
|
||||
// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。
|
||||
type rpmOverrideRepoStub struct {
|
||||
UserGroupRateRepository
|
||||
|
||||
override *int
|
||||
err error
|
||||
calls int32
|
||||
}
|
||||
|
||||
func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.override, nil
|
||||
}
|
||||
|
||||
// newBillingServiceForRPM builds a BillingCacheService wired with only the RPM
// counter cache and the override repository, and registers svc.Stop as test
// cleanup.
func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
	t.Helper()
	// A nil BillingCache takes the "no cache" path, avoiding CheckBillingEligibility side effects.
	// These tests only exercise checkRPM directly.
	svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
	t.Cleanup(svc.Stop)
	return svc
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup checks that a
// positive override (2) is enforced in place of group.RPMLimit (100).
func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
	override := 2
	// user-group counts: 1, 2, 3; user counts default to 1 (far below RPMLimit=100, so they do not interfere).
	cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{override: &override}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 100} // global cap set high so it does not interfere with the override test
	group := &Group{ID: 10, RPMLimit: 100}

	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded)

	require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数")
	// The first two calls pass the override and go on to the user check; the third exceeds
	// the override and returns immediately without touching the user counter.
	require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用")
	require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls))
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap checks that
// user.RPMLimit still rejects requests when the override is far higher.
func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
	override := 100 // override is very high
	// user-group counts default to 1 (far below the override); user counts: 1, 2, 3.
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{override: &override}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 2} // global hard cap = 2, must win over override = 100
	group := &Group{ID: 10, RPMLimit: 100}

	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies checks
// that override=0 bypasses the group-level counter while user.RPMLimit is still enforced.
func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
	zero := 0
	// user counts: 1..6 in order.
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
	repo := &rpmOverrideRepoStub{override: &zero}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 5}
	group := &Group{ID: 10, RPMLimit: 100}

	// override=0 skips the group counter, but user.RPMLimit=5 still applies.
	for i := 0; i < 5; i++ {
		require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1)
	}
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded,
		"override=0 跳过分组但 user 全局上限仍应生效")
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器")
	require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用")
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited checks
// that override=0 combined with user.RPMLimit=0 touches no counter at all.
func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
	zero := 0
	cache := &userRPMCacheStub{}
	repo := &rpmOverrideRepoStub{override: &zero}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 0} // the user is unlimited as well
	group := &Group{ID: 10, RPMLimit: 100}

	for i := 0; i < 50; i++ {
		require.NoError(t, svc.checkRPM(context.Background(), user, group))
	}
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数")
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数")
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup checks that
// group.RPMLimit is enforced when the repository reports no override.
func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
	// user-group counts: 5, 6; user counts default to 1 (no interference).
	cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 999} // global cap is very high, so the group limit trips first
	group := &Group{ID: 10, RPMLimit: 5}

	require.NoError(t, svc.checkRPM(context.Background(), user, group))                      // ug=5, user=1: neither exceeded
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5

	require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls))
	// The first call passes the group limit and goes on to the user check; the second exceeds
	// the group limit and returns immediately without touching the user counter.
	require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回")
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup checks that
// a failed override lookup is treated as "no override" and the group limit is used.
func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
	cache := &userRPMCacheStub{userGroupCounts: []int{3}}
	repo := &rpmOverrideRepoStub{err: errors.New("db down")}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 10}

	// After the override lookup fails, the group branch must still be tried (no outright rejection).
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
	require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls))
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited checks that
// only the user-level counter is used when the group has no limit and no override.
func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 2}
	group := &Group{ID: 10, RPMLimit: 0} // the group has no limit

	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded)

	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键")
	require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop checks that no counter
// is touched when neither the user, the group, nor an override sets a limit.
func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
	cache := &userRPMCacheStub{}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 0}

	for i := 0; i < 10; i++ {
		require.NoError(t, svc.checkRPM(context.Background(), user, group))
	}
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_RedisErrorFailOpen checks that an error from the
// (user, group) counter does not reject the request.
func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
	cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 5}

	// On a counter failure the check must fail open and let the request through.
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly checks that with a nil group
// only the user-level limit is enforced and the override repository is never queried.
func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{}
	svc := newBillingServiceForRPM(t, cache, repo)

	user := &User{ID: 1, RPMLimit: 2}

	// No group (pure user-level limiting): rpm_override must not be looked up.
	require.NoError(t, svc.checkRPM(context.Background(), user, nil))
	require.NoError(t, svc.checkRPM(context.Background(), user, nil))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded)

	require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override")
	require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
|
||||
|
||||
// TestBillingCacheService_CheckRPM_NilUserIsNoop checks that a nil user short-circuits
// checkRPM before any counter or repository call.
func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
	cache := &userRPMCacheStub{}
	repo := &rpmOverrideRepoStub{}
	svc := newBillingServiceForRPM(t, cache, repo)

	require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10}))
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls))
}
|
||||
@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||
delay: 80 * time.Millisecond,
|
||||
balance: 12.34,
|
||||
}
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
const goroutines = 16
|
||||
|
||||
@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
|
||||
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
|
||||
svc.Stop()
|
||||
|
||||
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
|
||||
|
||||
@@ -170,9 +170,10 @@ const (
|
||||
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
|
||||
SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制)
|
||||
|
||||
// 第三方认证来源默认授予配置
|
||||
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
|
||||
|
||||
@@ -59,6 +59,10 @@ type Group struct {
|
||||
DefaultMappedModel string
|
||||
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
|
||||
|
||||
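// RPMLimit is the group-level requests-per-minute limit (0 = unlimited).
|
||||
// A (user, group) rpm_override, when present, is used instead of this value. It does not replace the user-level rpm_limit: both are enforced.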
|
||||
RPMLimit int
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
|
||||
@@ -1060,6 +1060,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
|
||||
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal default subscriptions: %w", err)
|
||||
@@ -1422,6 +1423,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
|
||||
return s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。
|
||||
func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit)
|
||||
if err != nil || value == "" {
|
||||
return 0
|
||||
}
|
||||
if v, err := strconv.Atoi(value); err == nil && v >= 0 {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
|
||||
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
|
||||
@@ -1590,6 +1603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyOIDCConnectUserInfoUsernamePath: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeyDefaultUserRPMLimit: "0",
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultEmailBalance: "0",
|
||||
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
|
||||
@@ -1699,6 +1713,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
|
||||
}
|
||||
|
||||
if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 {
|
||||
result.DefaultUserRPMLimit = rpm
|
||||
}
|
||||
|
||||
// 解析浮点数类型
|
||||
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
|
||||
result.DefaultBalance = balance
|
||||
|
||||
@@ -106,6 +106,7 @@ type SystemSettings struct {
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
DefaultUserRPMLimit int
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting
|
||||
|
||||
// Model fallback configuration
|
||||
|
||||
@@ -49,6 +49,15 @@ type User struct {
|
||||
BalanceNotifyExtraEmails []NotifyEmailEntry
|
||||
TotalRecharged float64
|
||||
|
||||
// RPMLimit is the user-level requests-per-minute limit (0 = unlimited). It is a global hard cap that is
|
||||
// enforced regardless of any group rpm_limit or (user, group) rpm_override; counter key rpm:u:{userID}:{min}.
|
||||
RPMLimit int
|
||||
|
||||
// UserGroupRPMOverride 来自 auth cache snapshot 的 (user, group) RPM 覆盖值。
|
||||
// nil = 该 API Key 对应的 (user, group) 无 override;非 nil 时 checkRPM 直接使用,
|
||||
// 避免每请求查 DB。字段不持久化到数据库。
|
||||
UserGroupRPMOverride *int
|
||||
|
||||
APIKeys []APIKey
|
||||
Subscriptions []UserSubscription
|
||||
}
|
||||
|
||||
@@ -2,14 +2,16 @@ package service
|
||||
|
||||
import "context"
|
||||
|
||||
// UserGroupRateEntry 分组下用户专属倍率条目
|
||||
// UserGroupRateEntry 分组下用户专属倍率/RPM 条目。
|
||||
// RateMultiplier 与 RPMOverride 均为指针以支持"未设置"语义(NULL)。
|
||||
type UserGroupRateEntry struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
UserEmail string `json:"user_email"`
|
||||
UserNotes string `json:"user_notes"`
|
||||
UserStatus string `json:"user_status"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
UserID int64 `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
UserEmail string `json:"user_email"`
|
||||
UserNotes string `json:"user_notes"`
|
||||
UserStatus string `json:"user_status"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier,omitempty"`
|
||||
RPMOverride *int `json:"rpm_override,omitempty"`
|
||||
}
|
||||
|
||||
// GroupRateMultiplierInput 批量设置分组倍率的输入条目
|
||||
@@ -18,30 +20,44 @@ type GroupRateMultiplierInput struct {
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
}
|
||||
|
||||
// UserGroupRateRepository 用户专属分组倍率仓储接口
|
||||
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
||||
// GroupRPMOverrideInput is one entry of a batch update of a group's per-user
// RPM overrides. RPMOverride is a *int so that nil can express "clear the
// override" for that user.
type GroupRPMOverrideInput struct {
	UserID      int64 `json:"user_id"`
	RPMOverride *int  `json:"rpm_override"`
}
|
||||
|
||||
// UserGroupRateRepository is the repository interface for per-user group rate
// multipliers and RPM overrides. It lets an administrator give a specific user
// a dedicated billing multiplier and RPM limit for a group, overriding the
// group's defaults.
type UserGroupRateRepository interface {
	// GetByUserID returns all of the user's per-group rate_multiplier values,
	// keyed by group ID (only non-NULL entries are returned).
	GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)

	// GetByUserAndGroup returns the user's rate_multiplier for one group
	// (nil when it is NULL).
	GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)

	// GetRPMOverrideByUserAndGroup returns the user's rpm_override for one group
	// (nil when it is NULL).
	GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error)

	// GetByGroupID returns the per-user settings of every user in the group
	// (an entry is returned when either rate or rpm_override is non-NULL).
	GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)

	// SyncUserGroupRates synchronizes the user's per-group rate multipliers;
	// a nil value clears the rate_multiplier for that group.
	SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error

	// SyncGroupRateMultipliers batch-synchronizes the group's per-user rate
	// multipliers (replaces the rate part for the whole group).
	SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error

	// SyncGroupRPMOverrides batch-synchronizes the group's per-user RPM overrides
	// (replaces the rpm_override part for the whole group). An entry whose
	// RPMOverride is nil clears that row's rpm_override; a non-nil value is upserted.
	SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error

	// ClearGroupRPMOverrides clears every rpm_override of the group (the rpm
	// part of the whole group is set to NULL).
	ClearGroupRPMOverrides(ctx context.Context, groupID int64) error

	// DeleteByGroupID deletes all per-user entries of the group (called when
	// the group is deleted).
	DeleteByGroupID(ctx context.Context, groupID int64) error

	// DeleteByUserID deletes all per-group entries of the user (called when
	// the user is deleted).
	DeleteByUserID(ctx context.Context, userID int64) error
}
|
||||
|
||||
25
backend/internal/service/user_rpm_cache.go
Normal file
25
backend/internal/service/user_rpm_cache.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// UserRPMCache is the counter interface for user-level and (user, group)-level
// requests-per-minute tracking.
//
// How it differs from the account-level RPMCache:
//   - RPMCache aggregates per external AI provider account (key: rpm:{accountID}:{min}).
//   - UserRPMCache aggregates per user or per (user, group), which closes the
//     path where one user creates several API keys to get around an RPM limit.
//     Keys look like rpm:ug:{userID}:{groupID}:{min} or rpm:u:{userID}:{min}.
type UserRPMCache interface {
	// IncrementUserGroupRPM atomically increments the (user, group) per-minute
	// counter and returns the new value. It serves both the group rpm_limit
	// check and the (user, group) rpm_override check.
	IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)

	// IncrementUserRPM atomically increments the user-level per-minute counter
	// and returns the new value. It serves the user-level rpm_limit check, which
	// BillingCacheService.checkRPM applies whenever user.RPMLimit > 0.
	IncrementUserRPM(ctx context.Context, userID int64) (count int, err error)

	// GetUserGroupRPM returns the RPM already used by (user, group) in the
	// current minute (read-only, no increment).
	GetUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)

	// GetUserRPM returns the RPM already used by the user in the current minute
	// (read-only, no increment).
	GetUserRPM(ctx context.Context, userID int64) (count int, err error)
}
|
||||
@@ -39,6 +39,11 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||
return NewEmailQueueService(emailService, 3)
|
||||
}
|
||||
|
||||
// ProvideOAuthRefreshAPI is the Wire provider for OAuthRefreshAPI. It delegates
// to NewOAuthRefreshAPI with the account repository and token cache only, so the
// default lock TTL is used.
// NOTE(review): the "default lock TTL" behavior lives in NewOAuthRefreshAPI,
// which is not visible here — confirm.
func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
	return NewOAuthRefreshAPI(accountRepo, tokenCache)
}
|
||||
|
||||
// ProvideTokenRefreshService creates and starts TokenRefreshService
|
||||
func ProvideTokenRefreshService(
|
||||
accountRepo AccountRepository,
|
||||
@@ -383,6 +388,19 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideBillingCacheService is the Wire provider for BillingCacheService. It
// forwards every dependency unchanged to NewBillingCacheService, including the
// two RPM dependencies: the user RPM counter cache (rpmCache) and the
// user-group rate repository (rateRepo) used for rpm_override lookups.
func ProvideBillingCacheService(
	cache BillingCache,
	userRepo UserRepository,
	subRepo UserSubscriptionRepository,
	apiKeyRepo APIKeyRepository,
	rpmCache UserRPMCache,
	rateRepo UserGroupRateRepository,
	cfg *config.Config,
) *BillingCacheService {
	return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg)
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all services
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
@@ -399,7 +417,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewDashboardService,
|
||||
ProvidePricingService,
|
||||
NewBillingService,
|
||||
NewBillingCacheService,
|
||||
ProvideBillingCacheService,
|
||||
NewAnnouncementService,
|
||||
NewAdminService,
|
||||
NewGatewayService,
|
||||
@@ -411,7 +429,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewCompositeTokenCacheInvalidator,
|
||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||
NewAntigravityOAuthService,
|
||||
NewOAuthRefreshAPI,
|
||||
ProvideOAuthRefreshAPI,
|
||||
ProvideGeminiTokenProvider,
|
||||
NewGeminiMessagesCompatService,
|
||||
ProvideAntigravityTokenProvider,
|
||||
|
||||
Reference in New Issue
Block a user