feat(rpm): RPM 限流模块优化

P0:
- rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7)
- 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数)

P1:
- ClearAll 按钮直连 DELETE API,带 loading 防重复
- 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点

优化:
- checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效
- Override/Group 变更后自动失效 auth cache
- fail-open 语义不变,Redis 故障不阻塞业务
This commit is contained in:
james-6-23
2026-04-23 03:33:52 +08:00
parent ef967d8f8a
commit dc5d42addc
79 changed files with 2831 additions and 140 deletions

View File

@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"sort"
"strings"
"time"
@@ -32,6 +33,7 @@ type AdminService interface {
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
@@ -50,6 +52,8 @@ type AdminService interface {
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
@@ -114,6 +118,7 @@ type CreateUserInput struct {
Notes string
Balance float64
Concurrency int
RPMLimit int
AllowedGroups []int64
}
@@ -124,6 +129,7 @@ type UpdateUserInput struct {
Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
RPMLimit *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
@@ -199,6 +205,8 @@ type CreateGroupInput struct {
RequireOAuthOnly bool
RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// RPMLimit is the group-level RPM cap (0 = unlimited).
RPMLimit int
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -234,6 +242,8 @@ type UpdateGroupInput struct {
RequireOAuthOnly *bool
RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
// RPMLimit is the group-level RPM cap (0 = unlimited); nil means "not provided" and leaves the stored value unchanged.
RPMLimit *int
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct {
MigratedKeys int64 // 迁移的 Key 数量
}
// UserRPMStatus describes a user's current per-minute RPM usage.
// It is the result type of AdminService.GetUserRPMStatus.
type UserRPMStatus struct {
// UserRPMUsed is the user-level request count read from the RPM cache.
UserRPMUsed int `json:"user_rpm_used"`
// UserRPMLimit is the user's configured rpm_limit.
UserRPMLimit int `json:"user_rpm_limit"`
// PerGroup holds one entry per group referenced by the user's API keys.
PerGroup []UserGroupRPMStatus `json:"per_group"`
}
// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
type UserGroupRPMStatus struct {
// GroupID identifies the group this entry describes.
GroupID int64 `json:"group_id"`
// GroupName is the group's name; left empty when the group metadata could not be loaded.
GroupName string `json:"group_name"`
// Used is the (user, group) request count read from the RPM cache.
Used int `json:"used"`
// Limit is the effective limit for the pair: the (user, group) override when one exists,
// otherwise the group's rpm_limit.
Limit int `json:"limit"`
// Source tells where Limit came from.
Source string `json:"source"` // "group" | "override"
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
@@ -463,6 +489,8 @@ const (
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
// ErrRPMStatusUnavailable is returned by GetUserRPMStatus when the admin service
// has no RPM cache configured; it carries HTTP status 501 (Not Implemented).
var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available")
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo UserRepository
@@ -472,6 +500,7 @@ type adminServiceImpl struct {
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
userRPMCache UserRPMCache
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
@@ -496,6 +525,7 @@ func NewAdminService(
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
userRPMCache UserRPMCache,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
@@ -514,6 +544,7 @@ func NewAdminService(
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
userRPMCache: userRPMCache,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
@@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
RPMLimit: input.RPMLimit,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
}
@@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
oldRPMLimit := user.RPMLimit
if input.Email != "" {
user.Email = input.Email
@@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.Concurrency = *input.Concurrency
}
if input.RPMLimit != nil {
user.RPMLimit = *input.RPMLimit
}
if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups
}
@@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
// RPMLimit is read by billing_cache_service.checkRPM as the user-level limit.
// Without invalidating the auth cache, a change would have no effect until the L2 TTL expires.
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
}
}
@@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return keys, result.Total, nil
}
// GetUserRPMStatus reports a user's current RPM usage together with the limits
// that apply to it: the user-level counter and limit, plus one entry per group
// referenced by the user's API keys (ordered by ascending group ID).
//
// It returns ErrRPMStatusUnavailable when no RPM cache is configured, and
// propagates errors from loading the user or listing the user's API keys.
// Counter reads and per-group metadata lookups are best-effort: failures are
// logged and the affected fields keep whatever value was obtained.
//
// PerGroup is always a non-nil slice so it serialises as a JSON array ("[]")
// rather than "null" when the user has no group-bound keys.
func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) {
	if s.userRPMCache == nil {
		return nil, ErrRPMStatusUnavailable
	}

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, err
	}

	// Best-effort: a cache read failure is logged, not returned.
	userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID)
	if err != nil {
		logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err)
	}

	groupIDs, err := s.listUserAPIKeyGroupIDs(ctx, userID)
	if err != nil {
		return nil, err
	}

	perGroup := make([]UserGroupRPMStatus, 0, len(groupIDs))
	for _, groupID := range groupIDs {
		perGroup = append(perGroup, s.userGroupRPMStatus(ctx, userID, groupID))
	}

	return &UserRPMStatus{
		UserRPMUsed:  userRPMUsed,
		UserRPMLimit: user.RPMLimit,
		PerGroup:     perGroup,
	}, nil
}

// listUserAPIKeyGroupIDs returns the distinct, ascending group IDs referenced by
// the user's API keys. It walks every page of keys instead of only the first
// one, so users with more keys than a single page are not silently truncated.
func (s *adminServiceImpl) listUserAPIKeyGroupIDs(ctx context.Context, userID int64) ([]int64, error) {
	const pageSize = 1000

	groupIDSet := make(map[int64]struct{})
	var seen int64
	for page := 1; ; page++ {
		keys, total, err := s.GetUserAPIKeys(ctx, userID, page, pageSize, "", "")
		if err != nil {
			return nil, err
		}
		for i := range keys {
			if id := keys[i].GroupID; id != nil && *id > 0 {
				groupIDSet[*id] = struct{}{}
			}
		}
		seen += int64(len(keys))
		// Stop on an empty page as well, so an inconsistent total cannot loop forever.
		if len(keys) == 0 || seen >= total {
			break
		}
	}

	groupIDs := make([]int64, 0, len(groupIDSet))
	for groupID := range groupIDSet {
		groupIDs = append(groupIDs, groupID)
	}
	// Map iteration order is random; sort for a deterministic response.
	sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
	return groupIDs, nil
}

// userGroupRPMStatus builds the status entry for one (user, group) pair. The
// group's rpm_limit is the default limit; a (user, group) override, when one
// exists, replaces it.
func (s *adminServiceImpl) userGroupRPMStatus(ctx context.Context, userID, groupID int64) UserGroupRPMStatus {
	used, err := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID)
	if err != nil {
		logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, err)
	}
	entry := UserGroupRPMStatus{
		GroupID: groupID,
		Used:    used,
	}

	if s.groupRepo != nil {
		group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID)
		switch {
		case groupErr != nil:
			logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr)
		case group != nil:
			entry.GroupName = group.Name
			entry.Limit = group.RPMLimit
			entry.Source = "group"
		}
	}

	if s.userGroupRateRepo != nil {
		override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID)
		switch {
		case overrideErr != nil:
			logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr)
		case override != nil:
			entry.Limit = *override
			entry.Source = "override"
		}
	}

	return entry
}
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
// Return mock data for now
return map[string]any{
@@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
RPMLimit: input.RPMLimit,
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil {
@@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
}
if input.RPMLimit != nil {
group.RPMLimit = *input.RPMLimit
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
@@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
@@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
// ClearGroupRPMOverrides removes every (user, group) RPM override of the given
// group. It is a no-op when no user-group rate repository is configured.
// After a successful clear the group's auth cache entries are invalidated,
// because the override is embedded in the auth cache snapshot (v7).
func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
	repo := s.userGroupRateRepo
	if repo == nil {
		return nil
	}

	err := repo.ClearGroupRPMOverrides(ctx, groupID)
	if err == nil && s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
	}
	return err
}
// BatchSetGroupRPMOverrides syncs the given (user, group) RPM overrides for a
// group. It is a no-op when no user-group rate repository is configured.
// Any negative override is rejected with an INVALID_RPM_OVERRIDE bad-request
// error before anything is written. After a successful sync the group's auth
// cache entries are invalidated, because the override is embedded in the auth
// cache snapshot (v7).
func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
	repo := s.userGroupRateRepo
	if repo == nil {
		return nil
	}

	// Validate the whole batch up front so a bad entry never causes a partial write.
	for i := range entries {
		value := entries[i].RPMOverride
		if value == nil || *value >= 0 {
			continue
		}
		return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", entries[i].UserID))
	}

	err := repo.SyncGroupRPMOverrides(ctx, groupID, entries)
	if err == nil && s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
	}
	return err
}
// UpdateGroupSortOrders passes the sort-order updates straight to the group
// repository and returns its error unchanged.
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates)
}

View File

@@ -5,8 +5,10 @@ package service
import (
"context"
"errors"
"net/http"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
syncedGroupID int64
syncedEntries []GroupRateMultiplierInput
syncGroupErr error
rpmSyncedGroupID int64
rpmSyncedEntries []GroupRPMOverrideInput
rpmSyncErr error
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call")
}
// GetRPMOverrideByUserAndGroup must not be reached by tests using this stub; it panics if called.
func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
panic("unexpected GetRPMOverrideByUserAndGroup call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr
@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
return s.syncGroupErr
}
// SyncGroupRPMOverrides records the group ID and entries it was called with so
// tests can assert on them, then returns the injected rpmSyncErr.
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
s.rpmSyncedGroupID = groupID
s.rpmSyncedEntries = entries
return s.rpmSyncErr
}
// ClearGroupRPMOverrides must not be reached by tests using this stub; it panics if called.
func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
panic("unexpected ClearGroupRPMOverrides call")
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr
@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{
10: {
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)},
},
},
}
@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName)
require.Equal(t, 1.5, entries[0].RateMultiplier)
require.NotNil(t, entries[0].RateMultiplier)
require.Equal(t, 1.5, *entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID)
require.Equal(t, 0.8, entries[1].RateMultiplier)
require.NotNil(t, entries[1].RateMultiplier)
require.Equal(t, 0.8, *entries[1].RateMultiplier)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
require.Contains(t, err.Error(), "sync failed")
})
}
// TestAdminService_BatchSetGroupRPMOverrides covers the two visible paths of
// BatchSetGroupRPMOverrides: valid entries are forwarded to the repository,
// and a negative override is rejected as a bad request before the repository
// is touched.
func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) {
	const groupID int64 = 10
	ctx := context.Background()

	t.Run("syncs entries to repo", func(t *testing.T) {
		stub := &userGroupRateRepoStubForGroupRate{}
		admin := &adminServiceImpl{userGroupRateRepo: stub}

		limit := 20
		want := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &limit}}

		require.NoError(t, admin.BatchSetGroupRPMOverrides(ctx, groupID, want))
		require.Equal(t, groupID, stub.rpmSyncedGroupID)
		require.Equal(t, want, stub.rpmSyncedEntries)
	})

	t.Run("rejects negative override as bad request", func(t *testing.T) {
		stub := &userGroupRateRepoStubForGroupRate{}
		admin := &adminServiceImpl{userGroupRateRepo: stub}

		invalid := -1
		input := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &invalid}}

		err := admin.BatchSetGroupRPMOverrides(ctx, groupID, input)
		require.Error(t, err)
		require.Equal(t, http.StatusBadRequest, infraerrors.Code(err))
		// The stub never saw a sync call, so its recorded group ID is still zero.
		require.Zero(t, stub.rpmSyncedGroupID)
	})
}

View File

@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K)
}
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformAnthropic,
Status: StatusActive,
RPMLimit: 10,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
groupRepo: repo,
authCacheInvalidator: invalidator,
}
rpmLimit := 60
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
RPMLimit: &rpmLimit,
})
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, 60, repo.updated.RPMLimit)
require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot变更后必须失效 API Key 认证缓存")
}
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}

View File

@@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call")
}
// GetRPMOverrideByUserAndGroup must not be reached by the list-users tests; it panics if called.
func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
panic("unexpected GetRPMOverrideByUserAndGroup call")
}
// SyncUserGroupRates must not be reached by the list-users tests; it panics if called.
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
@@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
panic("unexpected SyncGroupRateMultipliers call")
}
// SyncGroupRPMOverrides must not be reached by the list-users tests; it panics if called.
func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error {
panic("unexpected SyncGroupRPMOverrides call")
}
// ClearGroupRPMOverrides must not be reached by the list-users tests; it panics if called.
func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
panic("unexpected ClearGroupRPMOverrides call")
}
// DeleteByGroupID must not be reached by the list-users tests; it panics if called.
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
panic("unexpected DeleteByGroupID call")
}

View File

@@ -0,0 +1,112 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// rpmStatusUserRepoStub is a UserRepository test double that overrides only GetByID.
// The embedded interface is left unset, so any other method call would hit a nil interface.
type rpmStatusUserRepoStub struct {
UserRepository
user *User // value returned by GetByID for every ID
}
// GetByID returns the configured user regardless of the requested ID and never fails.
func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
return s.user, nil
}
// rpmStatusAPIKeyRepoStub is an APIKeyRepository test double that overrides only ListByUserID.
type rpmStatusAPIKeyRepoStub struct {
APIKeyRepository
keys []APIKey // keys returned by ListByUserID
}
// ListByUserID ignores the user ID, pagination and filters: it always returns
// every configured key with Total set to the number of keys.
func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil
}
// rpmStatusGroupRepoStub is a GroupRepository test double that overrides only GetByIDLite.
type rpmStatusGroupRepoStub struct {
GroupRepository
groups map[int64]*Group // groups keyed by ID
}
// GetByIDLite looks the group up in the map; an unknown ID yields (nil, nil).
func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) {
return s.groups[id], nil
}
// rpmStatusRateRepoStub is a UserGroupRateRepository test double that overrides
// only GetRPMOverrideByUserAndGroup.
type rpmStatusRateRepoStub struct {
UserGroupRateRepository
overrides map[int64]*int // RPM override keyed by group ID
}
// GetRPMOverrideByUserAndGroup ignores the user ID and returns the override
// configured for the group; a group without an entry yields (nil, nil).
func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) {
return s.overrides[groupID], nil
}
// rpmStatusCacheStub is a UserRPMCache test double with fixed usage counters.
type rpmStatusCacheStub struct {
UserRPMCache
userUsed int // value returned by GetUserRPM
groupUsed map[int64]int // value returned by GetUserGroupRPM, keyed by group ID
}
// IncrementUserGroupRPM is a no-op that always returns (0, nil).
func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) {
return 0, nil
}
// IncrementUserRPM is a no-op that always returns (0, nil).
func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) {
return 0, nil
}
// GetUserGroupRPM ignores the user ID and returns the counter configured for
// the group (0 when the group has no entry).
func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) {
return s.groupUsed[groupID], nil
}
// GetUserRPM returns the configured user-level counter for any user ID.
func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) {
return s.userUsed, nil
}
func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) {
groupOneID := int64(1)
groupTwoID := int64(2)
override := 7
svc := &adminServiceImpl{
userRepo: &rpmStatusUserRepoStub{user: &User{
ID: 42,
RPMLimit: 20,
}},
apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{
{ID: 100, UserID: 42, GroupID: &groupTwoID},
{ID: 101, UserID: 42, GroupID: &groupOneID},
{ID: 102, UserID: 42, GroupID: &groupTwoID},
{ID: 103, UserID: 42},
}},
groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{
groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10},
groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60},
}},
userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{
groupTwoID: &override,
}},
userRPMCache: &rpmStatusCacheStub{
userUsed: 5,
groupUsed: map[int64]int{
groupOneID: 3,
groupTwoID: 4,
},
},
}
status, err := svc.GetUserRPMStatus(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, &UserRPMStatus{
UserRPMUsed: 5,
UserRPMLimit: 20,
PerGroup: []UserGroupRPMStatus{
{GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"},
{GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"},
},
}, status)
}

View File

@@ -0,0 +1,69 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// rpmUserRepoStub builds on the base userRepoStub (the original comment places
// it in admin_service_update_balance_test.go) and overrides only Update, which
// stores a copy of the user it receives so tests can assert on the updated
// RPMLimit.
type rpmUserRepoStub struct {
	*userRepoStub
	lastUpdated *User // copy of the user passed to the most recent Update call
}

// Update records a copy of user as lastUpdated and, when a base stub is set,
// also makes that copy the base stub's current user. A nil user is ignored.
// It never returns an error.
func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error {
	if user != nil {
		snapshot := new(User)
		*snapshot = *user
		s.lastUpdated = snapshot
		if base := s.userRepoStub; base != nil {
			base.user = snapshot
		}
	}
	return nil
}
func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}}
repo := &rpmUserRepoStub{userRepoStub: base}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: &redeemRepoStub{},
authCacheInvalidator: invalidator,
}
newRPM := 60
updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
RPMLimit: &newRPM,
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 60, updated.RPMLimit)
require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存")
}
func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) {
base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}}
repo := &rpmUserRepoStub{userRepoStub: base}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: &redeemRepoStub{},
authCacheInvalidator: invalidator,
}
newName := "new"
sameRPM := 10
_, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
Username: &newName,
RPMLimit: &sameRPM,
})
require.NoError(t, err)
require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效")
}

View File

@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
TotalRecharged float64 `json:"total_recharged"`
// RPMLimit is the user-level requests-per-minute cap (0 = unlimited); billing_cache_service.checkRPM enforces it as the user-wide limit.
RPMLimit int `json:"rpm_limit"`
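// UserGroupRPMOverride is the (user, group) RPM override for the group this API key is bound to.
// nil = no override in the snapshot (the group/user limits apply); 0 = no group-level limit for this user; >0 = dedicated cap.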
UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"`
}
// APIKeyAuthGroupSnapshot 分组快照
@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
// RPMLimit is the group-level requests-per-minute cap (0 = unlimited); billing_cache_service.checkRPM applies it when no (user, group) override exists.
RPMLimit int `json:"rpm_limit"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存

View File

@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
type apiKeyAuthCacheConfig struct {
l1Size int
@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
snapshot := s.snapshotFromAPIKey(apiKey)
snapshot := s.snapshotFromAPIKey(ctx, apiKey)
if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
}
@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil {
return nil
}
@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged,
RPMLimit: apiKey.User.RPMLimit,
},
}
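// Populate the (user, group) RPM override with one DB lookup while building the snapshot; when an override exists, requests served from the cached snapshot read it without another DB round-trip.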
if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
if err == nil && override != nil {
snapshot.User.UserGroupRPMOverride = override
}
// On lookup failure, or when no override exists, the field stays nil and checkRPM falls back to a DB lookup.
}
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
RPMLimit: apiKey.Group.RPMLimit,
}
}
return snapshot
@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged,
RPMLimit: snapshot.User.RPMLimit,
UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
},
}
if snapshot.Group != nil {
@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
RPMLimit: snapshot.Group.RPMLimit,
}
}
s.compileAPIKeyIPRules(apiKey)

View File

@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
},
}
snapshot := svc.snapshotFromAPIKey(apiKey)
snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)

View File

@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// Default RPM limit for new users (0 = unlimited). Written at registration and used afterwards as the user-level limit.
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
// 创建用户
user := &User{
Email: email,
@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
}
@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
Email: email,
@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
Email: email,
@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}

View File

@@ -20,6 +20,9 @@ import (
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
@@ -87,6 +90,8 @@ type BillingCacheService struct {
userRepo UserRepository
subRepo UserSubscriptionRepository
apiKeyRateLimitLoader apiKeyRateLimitLoader
userRPMCache UserRPMCache
userGroupRateRepo UserGroupRateRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
@@ -104,12 +109,22 @@ type BillingCacheService struct {
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
func NewBillingCacheService(
cache BillingCache,
userRepo UserRepository,
subRepo UserSubscriptionRepository,
apiKeyRepo APIKeyRepository,
userRPMCache UserRPMCache,
userGroupRateRepo UserGroupRateRepository,
cfg *config.Config,
) *BillingCacheService {
svc := &BillingCacheService{
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
apiKeyRateLimitLoader: apiKeyRepo,
userRPMCache: userRPMCache,
userGroupRateRepo: userGroupRateRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
}
}
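// RPM limiting (the override/group check plus the user-level cap). It runs last so that requests already rejected by the earlier checks do not bump the RPM counters.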
if err := s.checkRPM(ctx, user, group); err != nil {
return err
}
return nil
}
// checkRPM enforces the RPM limits that apply to this request. All applicable
// limits are checked in sequence and the request is rejected as soon as one
// of them is exceeded:
//
//  1. (user, group) rpm_override — the finest granularity: a dedicated limit an
//     administrator set for this user in this group. override == 0 exempts the
//     user from the group-level check, but the user-level cap still applies.
//  2. group.rpm_limit — the group-wide limit; checked only when no override
//     was resolved.
//  3. user.rpm_limit — the user-level hard cap; checked whenever it is > 0,
//     no matter how override/group are configured.
//
// Unlike the previous mutually exclusive cascade, user.rpm_limit can no longer
// be bypassed by a group limit or an override.
//
// Returns ErrGroupRPMExceeded or ErrUserRPMExceeded when a counter exceeds its
// limit, nil otherwise. Returns nil without checking anything when s,
// s.userRPMCache or user is nil. RPM cache failures are logged as warnings and
// treated as fail-open so they never block traffic.
//
// Each check increments its counter as part of the check, so the (user, group)
// counter has already been bumped when the user-level check rejects a request.
//
// NOTE(review): a nil user.UserGroupRPMOverride means either "not loaded" or
// "no override", so users without an override appear to take the DB lookup
// below on every request — confirm whether that is intended.
func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
if s == nil || s.userRPMCache == nil || user == nil {
return nil
}
// ── Layer 1: group-level check (override or group.rpm_limit) ──
if group != nil {
// Resolve the override: prefer the value carried on the user (auth cache snapshot); fall back to the DB when it is nil.
var override *int
if user.UserGroupRPMOverride != nil {
override = user.UserGroupRPMOverride
} else if s.userGroupRateRepo != nil {
dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm override lookup failed for user=%d group=%d: %v",
user.ID, group.ID, err,
)
} else {
override = dbOverride
}
}
if override != nil {
// override == 0 → the user is exempt in this group (the user-level cap is still checked below).
if *override > 0 {
count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
if incErr != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (override) failed for user=%d group=%d: %v",
user.ID, group.ID, incErr,
)
// fail-open
} else if count > *override {
return ErrGroupRPMExceeded
}
}
// An override replaces group.rpm_limit, so that check is skipped — but do not return: the user-level check still runs.
} else if group.RPMLimit > 0 {
// No override: enforce group.rpm_limit.
count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (group) failed for user=%d group=%d: %v",
user.ID, group.ID, err,
)
// fail-open
} else if count > group.RPMLimit {
return ErrGroupRPMExceeded
}
}
}
// ── Layer 2: user-level hard cap (always enforced when configured) ──
if user.RPMLimit > 0 {
count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
if err != nil {
logger.LegacyPrintf(
"service.billing_cache",
"Warning: rpm increment (user) failed for user=%d: %v",
user.ID, err,
)
return nil // fail-open
}
if count > user.RPMLimit {
return ErrUserRPMExceeded
}
}
return nil
}

View File

@@ -0,0 +1,253 @@
//go:build unit
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// userRPMCacheStub is a test double for the RPM counter cache. It counts how
// often each Increment method is called and lets a test inject the values and
// errors those methods return.
type userRPMCacheStub struct {
	// Call counters, updated with sync/atomic.
	userGroupCalls int32
	userCalls      int32

	// Scripted results for IncrementUserGroupRPM: the i-th call returns
	// userGroupCounts[i]; once the slice is exhausted the call returns 1.
	userGroupCounts []int
	// Scripted results for IncrementUserRPM, same convention as above.
	userCounts []int

	// When set, the matching Increment method returns (0, err) on every call.
	userGroupErr error
	userErr      error
}

// IncrementUserGroupRPM records the call and returns the injected error, the
// next scripted count, or 1 when the script is exhausted. The call is counted
// even when an error is returned.
func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
	call := int(atomic.AddInt32(&s.userGroupCalls, 1))
	switch {
	case s.userGroupErr != nil:
		return 0, s.userGroupErr
	case call <= len(s.userGroupCounts):
		return s.userGroupCounts[call-1], nil
	default:
		return 1, nil
	}
}

// IncrementUserRPM records the call and returns the injected error, the next
// scripted count, or 1 when the script is exhausted. The call is counted even
// when an error is returned.
func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
	call := int(atomic.AddInt32(&s.userCalls, 1))
	switch {
	case s.userErr != nil:
		return 0, s.userErr
	case call <= len(s.userCounts):
		return s.userCounts[call-1], nil
	default:
		return 1, nil
	}
}

// GetUserGroupRPM always reports zero usage and no error.
func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
	return 0, nil
}

// GetUserRPM always reports zero usage and no error.
func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
	return 0, nil
}
// rpmOverrideRepoStub is a repository double for the checkRPM branch tests.
// It embeds UserGroupRateRepository to satisfy the interface and implements
// only GetRPMOverrideByUserAndGroup; every other method is left to the (nil)
// embedded value.
type rpmOverrideRepoStub struct {
	UserGroupRateRepository

	calls    int32 // number of lookups, updated with sync/atomic
	override *int  // value returned on success (nil = no override)
	err      error // when set, returned instead of override
}

// GetRPMOverrideByUserAndGroup counts the lookup and returns the injected
// error if there is one, otherwise the configured override.
func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
	atomic.AddInt32(&s.calls, 1)
	if s.err == nil {
		return s.override, nil
	}
	return nil, s.err
}
// newBillingServiceForRPM builds a BillingCacheService wired with only the RPM
// counter cache and the override repository; every other dependency is nil.
// The service is stopped automatically when the test finishes.
//
// NOTE(review): the nil BillingCache is presumably meant to take the
// "no cache" path so that only checkRPM is exercised — confirm against
// NewBillingCacheService.
func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
	t.Helper()

	cfg := &config.Config{}
	svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, cfg)
	t.Cleanup(svc.Stop)
	return svc
}
// A positive override replaces group.rpm_limit: with override=2 and a group
// limit of 100, the third request in the minute is rejected.
func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
	ctx := context.Background()
	limit := 2

	// The user-group counter reports 1, 2, 3; the user counter falls back to
	// its default of 1, far below RPMLimit=100.
	counters := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
	overrides := &rpmOverrideRepoStub{override: &limit}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 100} // high cap so it does not interfere
	group := &Group{ID: 10, RPMLimit: 100}

	for i := 0; i < 2; i++ {
		require.NoError(t, svc.checkRPM(ctx, user, group))
	}
	require.ErrorIs(t, svc.checkRPM(ctx, user, group), ErrGroupRPMExceeded)

	require.EqualValues(t, 3, atomic.LoadInt32(&counters.userGroupCalls), "override 命中分支应走 user-group 计数")
	// The first two calls pass the override check and go on to the user
	// counter; the third is rejected before the user counter is touched.
	require.EqualValues(t, 2, atomic.LoadInt32(&counters.userCalls), "override 超限前 user 计数器应被调用")
	require.EqualValues(t, 3, atomic.LoadInt32(&overrides.calls))
}
// The user-level limit is a global ceiling: a generous override (100) cannot
// lift a user cap of 2.
func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
	ctx := context.Background()
	generous := 100

	// The user-group counter defaults to 1 (well under the override); the
	// user counter reports 1, 2, 3.
	counters := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	overrides := &rpmOverrideRepoStub{override: &generous}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 2}
	group := &Group{ID: 10, RPMLimit: 100}

	for i := 0; i < 2; i++ {
		require.NoError(t, svc.checkRPM(ctx, user, group))
	}
	require.ErrorIs(t, svc.checkRPM(ctx, user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
}
// override=0 exempts the user from the group-scope counter, yet the user-level
// limit (5) keeps being enforced.
func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
	ctx := context.Background()
	exempt := 0

	// The user counter reports 1 through 6.
	counters := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
	overrides := &rpmOverrideRepoStub{override: &exempt}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 5}
	group := &Group{ID: 10, RPMLimit: 100}

	for attempt := 1; attempt <= 5; attempt++ {
		require.NoError(t, svc.checkRPM(ctx, user, group), "request %d should pass", attempt)
	}
	require.ErrorIs(t, svc.checkRPM(ctx, user, group), ErrUserRPMExceeded,
		"override=0 跳过分组但 user 全局上限仍应生效")

	require.EqualValues(t, 0, atomic.LoadInt32(&counters.userGroupCalls), "override=0 不应触发分组计数器")
	require.EqualValues(t, 6, atomic.LoadInt32(&counters.userCalls), "user 计数器应被调用")
}
// With override=0 and user.RPMLimit=0 nothing is limited and no counter is
// ever incremented.
func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
	ctx := context.Background()
	exempt := 0

	counters := &userRPMCacheStub{}
	overrides := &rpmOverrideRepoStub{override: &exempt}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 0} // no user-level limit either
	group := &Group{ID: 10, RPMLimit: 100}

	for remaining := 50; remaining > 0; remaining-- {
		require.NoError(t, svc.checkRPM(ctx, user, group))
	}

	require.EqualValues(t, 0, atomic.LoadInt32(&counters.userGroupCalls), "override=0 不触发分组计数")
	require.EqualValues(t, 0, atomic.LoadInt32(&counters.userCalls), "user.RPMLimit=0 也不触发用户计数")
}
// Without an override the group limit applies: group.RPMLimit=5 rejects the
// request that pushes the user-group counter to 6.
func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
	ctx := context.Background()

	// The user-group counter reports 5 then 6; the user counter defaults to 1.
	counters := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
	overrides := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 999} // very high cap: the group limit trips first
	group := &Group{ID: 10, RPMLimit: 5}

	first := svc.checkRPM(ctx, user, group) // ug=5, user=1: both within limits
	require.NoError(t, first)
	second := svc.checkRPM(ctx, user, group) // ug=6 > 5
	require.ErrorIs(t, second, ErrGroupRPMExceeded)

	require.EqualValues(t, 2, atomic.LoadInt32(&counters.userGroupCalls))
	// Call 1 passes the group check and proceeds to the user counter; call 2
	// is rejected by the group check and returns before the user counter.
	require.EqualValues(t, 1, atomic.LoadInt32(&counters.userCalls), "group 未超时 user 也应检查group 超时直接返回")
}
// A failing override lookup must not reject the request; the group limit is
// evaluated instead.
func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
	ctx := context.Background()

	counters := &userRPMCacheStub{userGroupCounts: []int{3}}
	overrides := &rpmOverrideRepoStub{err: errors.New("db down")}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 10}

	result := svc.checkRPM(ctx, user, group)
	require.NoError(t, result)

	require.EqualValues(t, 1, atomic.LoadInt32(&counters.userGroupCalls))
	require.EqualValues(t, 1, atomic.LoadInt32(&overrides.calls))
}
// When the group has no limit and there is no override, only the user-level
// limit is enforced and the user-group counter is never touched.
func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
	ctx := context.Background()

	counters := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	overrides := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 2}
	group := &Group{ID: 10, RPMLimit: 0} // group has no limit configured

	for i := 0; i < 2; i++ {
		require.NoError(t, svc.checkRPM(ctx, user, group))
	}
	require.ErrorIs(t, svc.checkRPM(ctx, user, group), ErrUserRPMExceeded)

	require.EqualValues(t, 0, atomic.LoadInt32(&counters.userGroupCalls), "group 未设限时不应 INCR user-group 键")
	require.EqualValues(t, 3, atomic.LoadInt32(&counters.userCalls))
}
// With no override, no group limit and no user limit, checkRPM never
// increments a counter and always succeeds.
func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
	ctx := context.Background()

	counters := &userRPMCacheStub{}
	overrides := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 0}

	for remaining := 10; remaining > 0; remaining-- {
		require.NoError(t, svc.checkRPM(ctx, user, group))
	}

	require.EqualValues(t, 0, atomic.LoadInt32(&counters.userGroupCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&counters.userCalls))
}
// A failing counter backend must not reject requests (fail-open); the
// increment is still attempted once.
func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
	ctx := context.Background()

	counters := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
	overrides := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 5}

	result := svc.checkRPM(ctx, user, group)
	require.NoError(t, result)
	require.EqualValues(t, 1, atomic.LoadInt32(&counters.userGroupCalls))
}
// With a nil group only the user-level limit applies and the override
// repository is never queried.
func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
	ctx := context.Background()

	counters := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	overrides := &rpmOverrideRepoStub{}
	svc := newBillingServiceForRPM(t, counters, overrides)

	user := &User{ID: 1, RPMLimit: 2}

	for i := 0; i < 2; i++ {
		require.NoError(t, svc.checkRPM(ctx, user, nil))
	}
	require.ErrorIs(t, svc.checkRPM(ctx, user, nil), ErrUserRPMExceeded)

	require.EqualValues(t, 0, atomic.LoadInt32(&overrides.calls), "无 group 时不应查询 rpm_override")
	require.EqualValues(t, 3, atomic.LoadInt32(&counters.userCalls))
}
// A nil user short-circuits checkRPM: no counter and no repository call.
func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
	counters := &userRPMCacheStub{}
	overrides := &rpmOverrideRepoStub{}
	svc := newBillingServiceForRPM(t, counters, overrides)

	limitedGroup := &Group{ID: 1, RPMLimit: 10}
	require.NoError(t, svc.checkRPM(context.Background(), nil, limitedGroup))

	require.EqualValues(t, 0, atomic.LoadInt32(&counters.userGroupCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&counters.userCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&overrides.calls))
}

View File

@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
delay: 80 * time.Millisecond,
balance: 12.34,
}
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
const goroutines = 16

View File

@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{

View File

@@ -170,9 +170,10 @@ const (
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表JSON 数组)
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表JSON
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表JSON
SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制0 = 不限制)
// 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"

View File

@@ -59,6 +59,10 @@ type Group struct {
DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
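// RPMLimit is the group-level requests-per-minute cap, counted per (user, group); 0 = unlimited.
// A user-group rpm_override replaces it for that user; the user-level rpm_limit is still enforced on top of it.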
RPMLimit int
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -1060,6 +1060,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
return nil, fmt.Errorf("marshal default subscriptions: %w", err)
@@ -1422,6 +1423,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
return s.cfg.Default.UserBalance
}
// GetDefaultUserRPMLimit returns the default RPM limit applied to new users
// (0 = unlimited). It returns 0 when the setting cannot be read, is empty, is
// not an integer, or is negative.
func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit)
	if err != nil || raw == "" {
		return 0
	}

	limit, convErr := strconv.Atoi(raw)
	if convErr != nil || limit < 0 {
		return 0
	}
	return limit
}
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
@@ -1590,6 +1603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectUserInfoUsernamePath: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultUserRPMLimit: "0",
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
@@ -1699,6 +1713,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
}
if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 {
result.DefaultUserRPMLimit = rpm
}
// 解析浮点数类型
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
result.DefaultBalance = balance

View File

@@ -106,6 +106,7 @@ type SystemSettings struct {
DefaultConcurrency int
DefaultBalance float64
DefaultUserRPMLimit int
DefaultSubscriptions []DefaultSubscriptionSetting
// Model fallback configuration

View File

@@ -49,6 +49,15 @@ type User struct {
BalanceNotifyExtraEmails []NotifyEmailEntry
TotalRecharged float64
// RPMLimit is the user-level requests-per-minute cap (0 = unlimited). checkRPM enforces it as a global
// ceiling regardless of any group rpm_limit or rpm_override; counter key rpm:u:{userID}:{min}.
RPMLimit int
// UserGroupRPMOverride 来自 auth cache snapshot 的 (user, group) RPM 覆盖值。
// nil = 该 API Key 对应的 (user, group) 无 override非 nil 时 checkRPM 直接使用,
// 避免每请求查 DB。字段不持久化到数据库。
UserGroupRPMOverride *int
APIKeys []APIKey
Subscriptions []UserSubscription
}

View File

@@ -2,14 +2,16 @@ package service
import "context"
// UserGroupRateEntry is one user's dedicated settings within a group: the
// billing rate multiplier and/or the RPM override.
//
// RateMultiplier and RPMOverride are pointers so that "not set" (NULL) can be
// told apart from a zero value; unset fields are omitted from the JSON output.
type UserGroupRateEntry struct {
	UserID         int64    `json:"user_id"`
	UserName       string   `json:"user_name"`
	UserEmail      string   `json:"user_email"`
	UserNotes      string   `json:"user_notes"`
	UserStatus     string   `json:"user_status"`
	RateMultiplier *float64 `json:"rate_multiplier,omitempty"`
	RPMOverride    *int     `json:"rpm_override,omitempty"`
}
// GroupRateMultiplierInput 批量设置分组倍率的输入条目
@@ -18,30 +20,44 @@ type GroupRateMultiplierInput struct {
RateMultiplier float64 `json:"rate_multiplier"`
}
// UserGroupRateRepository 用户专属分组倍率仓储接口
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
// GroupRPMOverrideInput is one entry of a batch update of a group's RPM
// overrides. RPMOverride is a *int so that nil can express "clear the
// override" for that user.
type GroupRPMOverrideInput struct {
	UserID      int64 `json:"user_id"`
	RPMOverride *int  `json:"rpm_override"`
}
// UserGroupRateRepository stores per-user, per-group settings: a dedicated
// billing rate multiplier and an RPM override. Administrators use it to give a
// specific user values that override the group defaults.
type UserGroupRateRepository interface {
	// GetByUserID returns every dedicated rate_multiplier of the user as
	// map[groupID]rateMultiplier; only non-NULL entries are included.
	GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)

	// GetByUserAndGroup returns the user's dedicated rate_multiplier in the
	// given group, or nil when it is NULL / not set.
	GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)

	// GetRPMOverrideByUserAndGroup returns the user's rpm_override in the given
	// group, or nil when it is NULL.
	GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error)

	// GetByGroupID returns the dedicated settings of all users in the group; an
	// entry is returned when either its rate or its rpm_override is non-NULL.
	GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)

	// SyncUserGroupRates synchronizes the user's dedicated rate multipliers.
	// rates maps groupID to the new value; nil clears that group's rate_multiplier.
	SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error

	// SyncGroupRateMultipliers replaces the rate part of the whole group's
	// dedicated settings with the given entries.
	SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error

	// SyncGroupRPMOverrides replaces the rpm_override part of the whole group's
	// dedicated settings. An entry with a nil RPMOverride clears that row's
	// rpm_override; a non-nil value is upserted.
	SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error

	// ClearGroupRPMOverrides sets rpm_override to NULL for every entry of the group.
	ClearGroupRPMOverrides(ctx context.Context, groupID int64) error

	// DeleteByGroupID deletes all dedicated entries of the group (called when
	// the group is deleted).
	DeleteByGroupID(ctx context.Context, groupID int64) error

	// DeleteByUserID deletes all dedicated entries of the user (called when the
	// user is deleted).
	DeleteByUserID(ctx context.Context, userID int64) error
}

View File

@@ -0,0 +1,25 @@
package service
import "context"
// UserRPMCache is the per-user / per-(user, group) RPM counter interface.
//
// How it differs from the account-level RPMCache:
//   - RPMCache aggregates by external AI provider account (key: rpm:{accountID}:{min}).
//   - UserRPMCache aggregates by user or by (user, group), which closes the
//     "create several API keys under one user to bypass RPM" loophole.
//     Keys look like rpm:ug:{userID}:{groupID}:{min} or rpm:u:{userID}:{min}.
type UserRPMCache interface {
	// IncrementUserGroupRPM atomically increments the (user, group) per-minute
	// counter and returns the new value. It backs both the group rpm_limit
	// check and the user-group rpm_override check.
	IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)

	// IncrementUserRPM atomically increments the per-user per-minute counter
	// and returns the new value. It backs the user-level rpm_limit check, which
	// runs whenever that limit is positive, independently of any group limit
	// or override.
	IncrementUserRPM(ctx context.Context, userID int64) (count int, err error)

	// GetUserGroupRPM returns the RPM already used by (user, group) in the
	// current minute; read-only, it does not increment.
	GetUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)

	// GetUserRPM returns the RPM already used by the user in the current
	// minute; read-only, it does not increment.
	GetUserRPM(ctx context.Context, userID int64) (count int, err error)
}

View File

@@ -39,6 +39,11 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
return NewEmailQueueService(emailService, 3)
}
// ProvideOAuthRefreshAPI is the Wire provider for OAuthRefreshAPI. It forwards
// the account repository and the token cache to NewOAuthRefreshAPI.
// NOTE(review): presumably NewOAuthRefreshAPI applies the default lock TTL —
// confirm against its definition.
func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
	refreshAPI := NewOAuthRefreshAPI(accountRepo, tokenCache)
	return refreshAPI
}
// ProvideTokenRefreshService creates and starts TokenRefreshService
func ProvideTokenRefreshService(
accountRepo AccountRepository,
@@ -383,6 +388,19 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
return svc
}
// ProvideBillingCacheService is the Wire provider for BillingCacheService. On
// top of the billing dependencies it passes the RPM counter cache (rpmCache)
// and the user-group rate repository (rateRepo) to NewBillingCacheService.
func ProvideBillingCacheService(
	cache BillingCache,
	userRepo UserRepository,
	subRepo UserSubscriptionRepository,
	apiKeyRepo APIKeyRepository,
	rpmCache UserRPMCache,
	rateRepo UserGroupRateRepository,
	cfg *config.Config,
) *BillingCacheService {
	svc := NewBillingCacheService(
		cache,
		userRepo,
		subRepo,
		apiKeyRepo,
		rpmCache,
		rateRepo,
		cfg,
	)
	return svc
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
@@ -399,7 +417,7 @@ var ProviderSet = wire.NewSet(
NewDashboardService,
ProvidePricingService,
NewBillingService,
NewBillingCacheService,
ProvideBillingCacheService,
NewAnnouncementService,
NewAdminService,
NewGatewayService,
@@ -411,7 +429,7 @@ var ProviderSet = wire.NewSet(
NewCompositeTokenCacheInvalidator,
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
NewAntigravityOAuthService,
NewOAuthRefreshAPI,
ProvideOAuthRefreshAPI,
ProvideGeminiTokenProvider,
NewGeminiMessagesCompatService,
ProvideAntigravityTokenProvider,