feat(channel): 渠道管理系统 — 多模式定价 + 统一计费解析
Cherry-picked from release/custom-0.1.106: a9117600
This commit is contained in:
@@ -371,13 +371,193 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||||
}
|
||||
|
||||
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
|
||||
// 仅覆盖渠道中非 nil 的价格字段,nil 字段使用默认定价
|
||||
func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelPricing == nil {
|
||||
return pricing, nil
|
||||
}
|
||||
if channelPricing.InputPrice != nil {
|
||||
pricing.InputPricePerToken = *channelPricing.InputPrice
|
||||
pricing.InputPricePerTokenPriority = *channelPricing.InputPrice
|
||||
}
|
||||
if channelPricing.OutputPrice != nil {
|
||||
pricing.OutputPricePerToken = *channelPricing.OutputPrice
|
||||
pricing.OutputPricePerTokenPriority = *channelPricing.OutputPrice
|
||||
}
|
||||
if channelPricing.CacheWritePrice != nil {
|
||||
pricing.CacheCreationPricePerToken = *channelPricing.CacheWritePrice
|
||||
pricing.CacheCreation5mPrice = *channelPricing.CacheWritePrice
|
||||
pricing.CacheCreation1hPrice = *channelPricing.CacheWritePrice
|
||||
}
|
||||
if channelPricing.CacheReadPrice != nil {
|
||||
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
|
||||
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
|
||||
}
|
||||
return pricing, nil
|
||||
}
|
||||
|
||||
// CalculateCostWithChannel 使用渠道定价计算费用
|
||||
// Deprecated: 使用 CalculateCostUnified 代替
|
||||
func (s *BillingService) CalculateCostWithChannel(model string, tokens UsageTokens, rateMultiplier float64, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, "", channelPricing)
|
||||
}
|
||||
|
||||
// --- 统一计费入口 ---
|
||||
|
||||
// CostInput 统一计费输入
|
||||
type CostInput struct {
|
||||
Ctx context.Context
|
||||
Model string
|
||||
GroupID *int64 // 用于渠道定价查找
|
||||
Tokens UsageTokens
|
||||
RequestCount int // 按次计费时使用
|
||||
SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等)
|
||||
RateMultiplier float64
|
||||
ServiceTier string // "priority","flex","" 等
|
||||
Resolver *ModelPricingResolver // 定价解析器
|
||||
}
|
||||
|
||||
// CalculateCostUnified 统一计费入口,支持三种计费模式。
|
||||
// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。
|
||||
func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, error) {
|
||||
if input.Resolver == nil {
|
||||
// 无 Resolver,回退到旧路径
|
||||
return s.calculateCostInternal(input.Model, input.Tokens, input.RateMultiplier, input.ServiceTier, nil)
|
||||
}
|
||||
|
||||
resolved := input.Resolver.Resolve(input.Ctx, PricingInput{
|
||||
Model: input.Model,
|
||||
GroupID: input.GroupID,
|
||||
})
|
||||
|
||||
if input.RateMultiplier <= 0 {
|
||||
input.RateMultiplier = 1.0
|
||||
}
|
||||
|
||||
switch resolved.Mode {
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
return s.calculatePerRequestCost(resolved, input)
|
||||
default: // BillingModeToken
|
||||
return s.calculateTokenCost(resolved, input)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateTokenCost 按 token 区间计费
|
||||
func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
|
||||
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
|
||||
|
||||
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
|
||||
if pricing == nil {
|
||||
return nil, fmt.Errorf("no pricing available for model: %s", input.Model)
|
||||
}
|
||||
|
||||
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
|
||||
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
cacheReadPricePerToken := pricing.CacheReadPricePerToken
|
||||
tierMultiplier := 1.0
|
||||
|
||||
if usePriorityServiceTierPricing(input.ServiceTier, pricing) {
|
||||
if pricing.InputPricePerTokenPriority > 0 {
|
||||
inputPricePerToken = pricing.InputPricePerTokenPriority
|
||||
}
|
||||
if pricing.OutputPricePerTokenPriority > 0 {
|
||||
outputPricePerToken = pricing.OutputPricePerTokenPriority
|
||||
}
|
||||
if pricing.CacheReadPricePerTokenPriority > 0 {
|
||||
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
|
||||
}
|
||||
} else {
|
||||
tierMultiplier = serviceTierCostMultiplier(input.ServiceTier)
|
||||
}
|
||||
|
||||
// 长上下文定价(仅在无区间定价时应用,区间定价已包含上下文分层)
|
||||
if len(resolved.Intervals) == 0 && s.shouldApplySessionLongContextPricing(input.Tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
}
|
||||
|
||||
breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken
|
||||
breakdown.OutputCost = float64(input.Tokens.OutputTokens) * outputPricePerToken
|
||||
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 {
|
||||
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||
float64(input.Tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||
}
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(input.Tokens.CacheReadTokens) * cacheReadPricePerToken
|
||||
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||||
breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier
|
||||
|
||||
return breakdown, nil
|
||||
}
|
||||
|
||||
// calculatePerRequestCost 按次/图片计费
|
||||
func (s *BillingService) calculatePerRequestCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
|
||||
count := input.RequestCount
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
}
|
||||
|
||||
var unitPrice float64
|
||||
|
||||
if input.SizeTier != "" {
|
||||
unitPrice = input.Resolver.GetRequestTierPrice(resolved, input.SizeTier)
|
||||
}
|
||||
|
||||
if unitPrice == 0 {
|
||||
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
|
||||
unitPrice = input.Resolver.GetRequestTierPriceByContext(resolved, totalContext)
|
||||
}
|
||||
|
||||
totalCost := unitPrice * float64(count)
|
||||
actualCost := totalCost * input.RateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, "", nil)
|
||||
}
|
||||
|
||||
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, serviceTier, nil)
|
||||
}
|
||||
|
||||
func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
|
||||
var pricing *ModelPricing
|
||||
var err error
|
||||
if channelPricing != nil {
|
||||
pricing, err = s.GetModelPricingWithChannel(model, channelPricing)
|
||||
} else {
|
||||
pricing, err = s.GetModelPricing(model)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
171
backend/internal/service/channel.go
Normal file
171
backend/internal/service/channel.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BillingMode identifies how a model's usage is billed.
type BillingMode string

const (
	// BillingModeToken bills by token, optionally with token intervals.
	BillingModeToken BillingMode = "token"
	// BillingModePerRequest bills per request (supports context-window tiers).
	BillingModePerRequest BillingMode = "per_request"
	// BillingModeImage bills images (currently per request; token billing is reserved).
	BillingModeImage BillingMode = "image"
)

// IsValid reports whether m is one of the known billing modes. The empty
// string is accepted as valid.
func (m BillingMode) IsValid() bool {
	return m == "" ||
		m == BillingModeToken ||
		m == BillingModePerRequest ||
		m == BillingModeImage
}
|
||||
|
||||
// Channel 渠道实体
|
||||
type Channel struct {
|
||||
ID int64
|
||||
Name string
|
||||
Description string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// 关联的分组 ID 列表
|
||||
GroupIDs []int64
|
||||
// 模型定价列表
|
||||
ModelPricing []ChannelModelPricing
|
||||
}
|
||||
|
||||
// ChannelModelPricing 渠道模型定价条目
|
||||
type ChannelModelPricing struct {
|
||||
ID int64
|
||||
ChannelID int64
|
||||
Models []string // 绑定的模型列表
|
||||
BillingMode BillingMode // 计费模式
|
||||
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
|
||||
OutputPrice *float64 // 每 token 输出价格(USD)
|
||||
CacheWritePrice *float64 // 缓存写入价格
|
||||
CacheReadPrice *float64 // 缓存读取价格
|
||||
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
|
||||
Intervals []PricingInterval // 区间定价列表
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// PricingInterval is one pricing tier: a token range in token mode, or a
// labelled tier in per-request / image mode (e.g. by image resolution).
type PricingInterval struct {
	ID        int64
	PricingID int64
	// MinTokens is the inclusive lower bound of the range.
	MinTokens int
	// MaxTokens is the exclusive upper bound of the range; nil means unbounded.
	MaxTokens *int
	// TierLabel is the tier label for per-request/image modes (1K, 2K, 4K, HD, ...).
	TierLabel string
	// InputPrice is the per-token input price (token mode).
	InputPrice *float64
	// OutputPrice is the per-token output price (token mode).
	OutputPrice *float64
	// CacheWritePrice is the cache write price (token mode).
	CacheWritePrice *float64
	// CacheReadPrice is the cache read price (token mode).
	CacheReadPrice *float64
	// PerRequestPrice is the price of one request (per-request/image modes).
	PerRequestPrice *float64
	SortOrder       int
	CreatedAt       time.Time
	UpdatedAt       time.Time
}
|
||||
|
||||
// IsActive 判断渠道是否启用
|
||||
func (c *Channel) IsActive() bool {
|
||||
return c.Status == StatusActive
|
||||
}
|
||||
|
||||
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
|
||||
// 优先精确匹配,然后通配符匹配(如 claude-opus-*)。大小写不敏感。
|
||||
// 返回值拷贝,不污染缓存。
|
||||
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
|
||||
modelLower := strings.ToLower(model)
|
||||
|
||||
// 第一轮:精确匹配
|
||||
for i := range c.ModelPricing {
|
||||
for _, m := range c.ModelPricing[i].Models {
|
||||
if strings.ToLower(m) == modelLower {
|
||||
cp := c.ModelPricing[i].Clone()
|
||||
return &cp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二轮:通配符匹配(仅支持末尾 *)
|
||||
for i := range c.ModelPricing {
|
||||
for _, m := range c.ModelPricing[i].Models {
|
||||
mLower := strings.ToLower(m)
|
||||
if strings.HasSuffix(mLower, "*") {
|
||||
prefix := strings.TrimSuffix(mLower, "*")
|
||||
if strings.HasPrefix(modelLower, prefix) {
|
||||
cp := c.ModelPricing[i].Clone()
|
||||
return &cp
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。
|
||||
// 通用辅助函数,供 GetIntervalForContext、ModelPricingResolver 等复用。
|
||||
func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval {
|
||||
for i := range intervals {
|
||||
iv := &intervals[i]
|
||||
if totalTokens >= iv.MinTokens && (iv.MaxTokens == nil || totalTokens < *iv.MaxTokens) {
|
||||
return iv
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIntervalForContext 根据总 context token 数查找匹配的区间。
|
||||
func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval {
|
||||
return FindMatchingInterval(p.Intervals, totalTokens)
|
||||
}
|
||||
|
||||
// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式)
|
||||
func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval {
|
||||
labelLower := strings.ToLower(label)
|
||||
for i := range p.Intervals {
|
||||
if strings.ToLower(p.Intervals[i].TierLabel) == labelLower {
|
||||
return &p.Intervals[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全)
|
||||
func (p ChannelModelPricing) Clone() ChannelModelPricing {
|
||||
cp := p
|
||||
if p.Models != nil {
|
||||
cp.Models = make([]string, len(p.Models))
|
||||
copy(cp.Models, p.Models)
|
||||
}
|
||||
if p.Intervals != nil {
|
||||
cp.Intervals = make([]PricingInterval, len(p.Intervals))
|
||||
copy(cp.Intervals, p.Intervals)
|
||||
}
|
||||
return cp
|
||||
}
|
||||
|
||||
// Clone 返回 Channel 的深拷贝
|
||||
func (c *Channel) Clone() *Channel {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *c
|
||||
if c.GroupIDs != nil {
|
||||
cp.GroupIDs = make([]int64, len(c.GroupIDs))
|
||||
copy(cp.GroupIDs, c.GroupIDs)
|
||||
}
|
||||
if c.ModelPricing != nil {
|
||||
cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing))
|
||||
for i := range c.ModelPricing {
|
||||
cp.ModelPricing[i] = c.ModelPricing[i].Clone()
|
||||
}
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
338
backend/internal/service/channel_service.go
Normal file
338
backend/internal/service/channel_service.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
|
||||
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
|
||||
ErrGroupAlreadyInChannel = infraerrors.Conflict(
|
||||
"GROUP_ALREADY_IN_CHANNEL",
|
||||
"one or more groups already belong to another channel",
|
||||
)
|
||||
)
|
||||
|
||||
// ChannelRepository 渠道数据访问接口
|
||||
type ChannelRepository interface {
|
||||
Create(ctx context.Context, channel *Channel) error
|
||||
GetByID(ctx context.Context, id int64) (*Channel, error)
|
||||
Update(ctx context.Context, channel *Channel) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
|
||||
ListAll(ctx context.Context) ([]Channel, error)
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error)
|
||||
|
||||
// 分组关联
|
||||
GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error)
|
||||
SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error
|
||||
GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
|
||||
|
||||
// 模型定价
|
||||
ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
|
||||
CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
DeleteModelPricing(ctx context.Context, id int64) error
|
||||
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
|
||||
}
|
||||
|
||||
// channelCache 渠道缓存快照
|
||||
type channelCache struct {
|
||||
// byID: channelID -> *Channel(含 ModelPricing)
|
||||
byID map[int64]*Channel
|
||||
// byGroupID: groupID -> channelID
|
||||
byGroupID map[int64]int64
|
||||
loadedAt time.Time
|
||||
}
|
||||
|
||||
const (
	// channelCacheTTL is how long a cache snapshot is served before a rebuild.
	channelCacheTTL = 60 * time.Second
	// channelErrorTTL is the short cache lifetime used after a DB error.
	channelErrorTTL = 5 * time.Second
	// channelCacheDBTimeout bounds the DB query that rebuilds the cache.
	channelCacheDBTimeout = 10 * time.Second
)
|
||||
|
||||
// ChannelService 渠道管理服务
|
||||
type ChannelService struct {
|
||||
repo ChannelRepository
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
|
||||
cache atomic.Value // *channelCache
|
||||
cacheSF singleflight.Group
|
||||
}
|
||||
|
||||
// NewChannelService 创建渠道服务实例
|
||||
func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
|
||||
s := &ChannelService{
|
||||
repo: repo,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// loadCache 加载或返回缓存的渠道数据
|
||||
func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) {
|
||||
if cached, ok := s.cache.Load().(*channelCache); ok {
|
||||
if time.Since(cached.loadedAt) < channelCacheTTL {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
result, err, _ := s.cacheSF.Do("channel_cache", func() (any, error) {
|
||||
// 双重检查
|
||||
if cached, ok := s.cache.Load().(*channelCache); ok {
|
||||
if time.Since(cached.loadedAt) < channelCacheTTL {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
return s.buildCache(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.(*channelCache), nil
|
||||
}
|
||||
|
||||
// buildCache 从数据库构建渠道缓存。
|
||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||
// 断开请求取消链,避免客户端断连导致空值被长期缓存
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
||||
defer cancel()
|
||||
|
||||
channels, err := s.repo.ListAll(dbCtx)
|
||||
if err != nil {
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := &channelCache{
|
||||
byID: make(map[int64]*Channel),
|
||||
byGroupID: make(map[int64]int64),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
}
|
||||
|
||||
cache := &channelCache{
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
byGroupID: make(map[int64]int64),
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
ch := &channels[i]
|
||||
cache.byID[ch.ID] = ch
|
||||
for _, gid := range ch.GroupIDs {
|
||||
cache.byGroupID[gid] = ch.ID
|
||||
}
|
||||
}
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// invalidateCache 使缓存失效,让下次读取时自然重建
|
||||
func (s *ChannelService) invalidateCache() {
|
||||
s.cache.Store((*channelCache)(nil))
|
||||
s.cacheSF.Forget("channel_cache")
|
||||
}
|
||||
|
||||
// GetChannelForGroup 获取分组关联的渠道(热路径,从缓存读取)
|
||||
// 返回深拷贝,不污染缓存。
|
||||
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channelID, ok := cache.byGroupID[groupID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ch, ok := cache.byID[channelID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if !ch.IsActive() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return ch.Clone(), nil
|
||||
}
|
||||
|
||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径)
|
||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||
ch, err := s.GetChannelForGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get channel for group", "group_id", groupID, "error", err)
|
||||
return nil
|
||||
}
|
||||
if ch == nil {
|
||||
return nil
|
||||
}
|
||||
return ch.GetModelPricing(model)
|
||||
}
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
// Create 创建渠道
|
||||
func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) (*Channel, error) {
|
||||
exists, err := s.repo.ExistsByName(ctx, input.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if len(input.GroupIDs) > 0 {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("create channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
return s.repo.GetByID(ctx, channel.ID)
|
||||
}
|
||||
|
||||
// GetByID 获取渠道详情
|
||||
func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Update 更新渠道
|
||||
func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChannelInput) (*Channel, error) {
|
||||
channel, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
if input.Name != "" && input.Name != channel.Name {
|
||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
channel.Name = input.Name
|
||||
}
|
||||
|
||||
if input.Description != nil {
|
||||
channel.Description = *input.Description
|
||||
}
|
||||
|
||||
if input.Status != "" {
|
||||
channel.Status = input.Status
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if input.GroupIDs != nil {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
channel.GroupIDs = *input.GroupIDs
|
||||
}
|
||||
|
||||
if input.ModelPricing != nil {
|
||||
channel.ModelPricing = *input.ModelPricing
|
||||
}
|
||||
|
||||
if err := s.repo.Update(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
// 失效关联分组的 auth 缓存
|
||||
if s.authCacheInvalidator != nil {
|
||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get group IDs for cache invalidation", "channel_id", id, "error", err)
|
||||
}
|
||||
for _, gid := range groupIDs {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Delete 删除渠道
|
||||
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
// 先获取关联分组用于失效缓存
|
||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
||||
}
|
||||
|
||||
if err := s.repo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
for _, gid := range groupIDs {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取渠道列表
|
||||
func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
|
||||
return s.repo.List(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// --- Input types ---
|
||||
|
||||
// CreateChannelInput 创建渠道输入
|
||||
type CreateChannelInput struct {
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
}
|
||||
|
||||
// UpdateChannelInput 更新渠道输入
|
||||
type UpdateChannelInput struct {
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
}
|
||||
210
backend/internal/service/channel_test.go
Normal file
210
backend/internal/service/channel_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// channelTestPtrFloat64 returns a pointer to a fresh float64 holding v.
func channelTestPtrFloat64(v float64) *float64 {
	p := new(float64)
	*p = v
	return p
}

// channelTestPtrInt returns a pointer to a fresh int holding v.
func channelTestPtrInt(v int) *int {
	p := new(int)
	*p = v
	return p
}
|
||||
|
||||
func TestGetModelPricing(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
{ID: 2, Models: []string{"claude-*"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(5e-6)},
|
||||
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantID int64
|
||||
wantNil bool
|
||||
}{
|
||||
{"exact match", "claude-sonnet-4", 1, false},
|
||||
{"case insensitive", "Claude-Sonnet-4", 1, false},
|
||||
{"wildcard match", "claude-opus-4-20250514", 2, false},
|
||||
{"exact takes priority over wildcard", "claude-sonnet-4", 1, false},
|
||||
{"not found", "gemini-3.1-pro", 0, true},
|
||||
{"per_request model", "gpt-5.1", 3, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ch.GetModelPricing(tt.model)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.wantID, result.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
result := ch.GetModelPricing("claude-sonnet-4")
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Modify the returned copy's slice — original should be unchanged
|
||||
result.Models = append(result.Models, "hacked")
|
||||
|
||||
// Original should be unchanged
|
||||
require.Equal(t, 1, len(ch.ModelPricing[0].Models))
|
||||
}
|
||||
|
||||
func TestGetModelPricing_EmptyPricing(t *testing.T) {
|
||||
ch := &Channel{ModelPricing: nil}
|
||||
require.Nil(t, ch.GetModelPricing("any-model"))
|
||||
|
||||
ch2 := &Channel{ModelPricing: []ChannelModelPricing{}}
|
||||
require.Nil(t, ch2.GetModelPricing("any-model"))
|
||||
}
|
||||
|
||||
func TestGetIntervalForContext(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens int
|
||||
wantPrice *float64
|
||||
wantNil bool
|
||||
}{
|
||||
{"first interval", 50000, channelTestPtrFloat64(1e-6), false},
|
||||
{"boundary: at min of second", 128000, channelTestPtrFloat64(2e-6), false},
|
||||
{"boundary: at max of first (exclusive)", 128000, channelTestPtrFloat64(2e-6), false},
|
||||
{"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false},
|
||||
{"zero tokens", 0, channelTestPtrFloat64(1e-6), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := p.GetIntervalForContext(tt.tokens)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, *tt.wantPrice, *result.InputPrice, 1e-12)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIntervalForContext_NoMatch(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)},
|
||||
},
|
||||
}
|
||||
require.Nil(t, p.GetIntervalForContext(5000))
|
||||
require.Nil(t, p.GetIntervalForContext(50000))
|
||||
}
|
||||
|
||||
func TestGetIntervalForContext_Empty(t *testing.T) {
|
||||
p := &ChannelModelPricing{Intervals: nil}
|
||||
require.Nil(t, p.GetIntervalForContext(1000))
|
||||
}
|
||||
|
||||
func TestGetTierByLabel(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)},
|
||||
{TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
label string
|
||||
wantNil bool
|
||||
want float64
|
||||
}{
|
||||
{"exact match", "1K", false, 0.04},
|
||||
{"case insensitive", "hd", false, 0.12},
|
||||
{"not found", "4K", true, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := p.GetTierByLabel(tt.label)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, tt.want, *result.PerRequestPrice, 1e-12)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTierByLabel_Empty(t *testing.T) {
|
||||
p := &ChannelModelPricing{Intervals: nil}
|
||||
require.Nil(t, p.GetTierByLabel("1K"))
|
||||
}
|
||||
|
||||
func TestChannelClone(t *testing.T) {
|
||||
original := &Channel{
|
||||
ID: 1,
|
||||
Name: "test",
|
||||
GroupIDs: []int64{10, 20},
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{
|
||||
ID: 100,
|
||||
Models: []string{"model-a"},
|
||||
InputPrice: channelTestPtrFloat64(5e-6),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cloned := original.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
require.Equal(t, original.ID, cloned.ID)
|
||||
require.Equal(t, original.Name, cloned.Name)
|
||||
|
||||
// Modify clone slices — original should not change
|
||||
cloned.GroupIDs[0] = 999
|
||||
require.Equal(t, int64(10), original.GroupIDs[0])
|
||||
|
||||
cloned.ModelPricing[0].Models[0] = "hacked"
|
||||
require.Equal(t, "model-a", original.ModelPricing[0].Models[0])
|
||||
}
|
||||
|
||||
func TestChannelClone_Nil(t *testing.T) {
|
||||
var ch *Channel
|
||||
require.Nil(t, ch.Clone())
|
||||
}
|
||||
|
||||
func TestChannelModelPricingClone(t *testing.T) {
|
||||
original := ChannelModelPricing{
|
||||
Models: []string{"a", "b"},
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, TierLabel: "tier1"},
|
||||
},
|
||||
}
|
||||
|
||||
cloned := original.Clone()
|
||||
|
||||
// Modify clone slices — original unchanged
|
||||
cloned.Models[0] = "hacked"
|
||||
require.Equal(t, "a", original.Models[0])
|
||||
|
||||
cloned.Intervals[0].TierLabel = "hacked"
|
||||
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
|
||||
}
|
||||
@@ -41,6 +41,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -568,6 +568,7 @@ type GatewayService struct {
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
channelService *ChannelService
|
||||
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
|
||||
tlsFPProfileService *TLSFingerprintProfileService
|
||||
}
|
||||
@@ -597,6 +598,7 @@ func NewGatewayService(
|
||||
digestStore *DigestSessionStore,
|
||||
settingService *SettingService,
|
||||
tlsFPProfileService *TLSFingerprintProfileService,
|
||||
channelService *ChannelService,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
@@ -629,6 +631,7 @@ func NewGatewayService(
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
tlsFPProfileService: tlsFPProfileService,
|
||||
channelService: channelService,
|
||||
}
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
@@ -7771,7 +7774,16 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
// 渠道定价覆盖
|
||||
var chPricing *ChannelModelPricing
|
||||
if s.channelService != nil && apiKey.Group != nil {
|
||||
chPricing = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
|
||||
}
|
||||
if chPricing != nil {
|
||||
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing)
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
@@ -7959,7 +7971,16 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
// 渠道定价覆盖
|
||||
var chPricing2 *ChannelModelPricing
|
||||
if s.channelService != nil && apiKey.Group != nil {
|
||||
chPricing2 = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
|
||||
}
|
||||
if chPricing2 != nil {
|
||||
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing2)
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
|
||||
198
backend/internal/service/model_pricing_resolver.go
Normal file
198
backend/internal/service/model_pricing_resolver.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// ResolvedPricing 统一定价解析结果
|
||||
type ResolvedPricing struct {
|
||||
// Mode 计费模式
|
||||
Mode BillingMode
|
||||
|
||||
// Token 模式:基础定价(来自 LiteLLM 或 fallback)
|
||||
BasePricing *ModelPricing
|
||||
|
||||
// Token 模式:区间定价列表(如有,覆盖 BasePricing 中的对应字段)
|
||||
Intervals []PricingInterval
|
||||
|
||||
// 按次/图片模式:分层定价
|
||||
RequestTiers []PricingInterval
|
||||
|
||||
// 来源标识
|
||||
Source string // "channel", "litellm", "fallback"
|
||||
|
||||
// 是否支持缓存细分
|
||||
SupportsCacheBreakdown bool
|
||||
}
|
||||
|
||||
// ModelPricingResolver 统一模型定价解析器。
|
||||
// 解析链:Channel → LiteLLM → Fallback。
|
||||
type ModelPricingResolver struct {
|
||||
channelService *ChannelService
|
||||
billingService *BillingService
|
||||
}
|
||||
|
||||
// NewModelPricingResolver 创建定价解析器实例
|
||||
func NewModelPricingResolver(channelService *ChannelService, billingService *BillingService) *ModelPricingResolver {
|
||||
return &ModelPricingResolver{
|
||||
channelService: channelService,
|
||||
billingService: billingService,
|
||||
}
|
||||
}
|
||||
|
||||
// PricingInput is the input to a pricing resolution.
type PricingInput struct {
	// Model is the name of the model to price.
	Model string
	// GroupID selects the group whose channel pricing should be consulted;
	// nil means channel pricing is not checked.
	GroupID *int64
}
|
||||
|
||||
// Resolve 解析模型定价。
|
||||
// 1. 获取基础定价(LiteLLM → Fallback)
|
||||
// 2. 如果指定了 GroupID,查找渠道定价并覆盖
|
||||
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
|
||||
// 1. 获取基础定价
|
||||
basePricing, source := r.resolveBasePricing(input.Model)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Source: source,
|
||||
SupportsCacheBreakdown: basePricing != nil && basePricing.SupportsCacheBreakdown,
|
||||
}
|
||||
|
||||
// 2. 如果有 GroupID,尝试渠道覆盖
|
||||
if input.GroupID != nil {
|
||||
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
// resolveBasePricing 从 LiteLLM 或 Fallback 获取基础定价
|
||||
func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing, string) {
|
||||
pricing, err := r.billingService.GetModelPricing(model)
|
||||
if err != nil {
|
||||
slog.Debug("failed to get model pricing from LiteLLM, using fallback",
|
||||
"model", model, "error", err)
|
||||
return nil, "fallback"
|
||||
}
|
||||
return pricing, "litellm"
|
||||
}
|
||||
|
||||
// applyChannelOverrides 应用渠道定价覆盖
|
||||
func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupID int64, model string, resolved *ResolvedPricing) {
|
||||
chPricing := r.channelService.GetChannelModelPricing(ctx, groupID, model)
|
||||
if chPricing == nil {
|
||||
return
|
||||
}
|
||||
|
||||
resolved.Source = "channel"
|
||||
resolved.Mode = chPricing.BillingMode
|
||||
if resolved.Mode == "" {
|
||||
resolved.Mode = BillingModeToken
|
||||
}
|
||||
|
||||
switch resolved.Mode {
|
||||
case BillingModeToken:
|
||||
r.applyTokenOverrides(chPricing, resolved)
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
r.applyRequestTierOverrides(chPricing, resolved)
|
||||
}
|
||||
}
|
||||
|
||||
// applyTokenOverrides 应用 token 模式的渠道覆盖
|
||||
func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
|
||||
// 如果有区间定价,使用区间
|
||||
if len(chPricing.Intervals) > 0 {
|
||||
resolved.Intervals = chPricing.Intervals
|
||||
return
|
||||
}
|
||||
|
||||
// 否则用 flat 字段覆盖 BasePricing
|
||||
if resolved.BasePricing == nil {
|
||||
resolved.BasePricing = &ModelPricing{}
|
||||
}
|
||||
|
||||
if chPricing.InputPrice != nil {
|
||||
resolved.BasePricing.InputPricePerToken = *chPricing.InputPrice
|
||||
resolved.BasePricing.InputPricePerTokenPriority = *chPricing.InputPrice
|
||||
}
|
||||
if chPricing.OutputPrice != nil {
|
||||
resolved.BasePricing.OutputPricePerToken = *chPricing.OutputPrice
|
||||
resolved.BasePricing.OutputPricePerTokenPriority = *chPricing.OutputPrice
|
||||
}
|
||||
if chPricing.CacheWritePrice != nil {
|
||||
resolved.BasePricing.CacheCreationPricePerToken = *chPricing.CacheWritePrice
|
||||
resolved.BasePricing.CacheCreation5mPrice = *chPricing.CacheWritePrice
|
||||
resolved.BasePricing.CacheCreation1hPrice = *chPricing.CacheWritePrice
|
||||
}
|
||||
if chPricing.CacheReadPrice != nil {
|
||||
resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice
|
||||
resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice
|
||||
}
|
||||
}
|
||||
|
||||
// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖
|
||||
func (r *ModelPricingResolver) applyRequestTierOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
|
||||
resolved.RequestTiers = chPricing.Intervals
|
||||
}
|
||||
|
||||
// GetIntervalPricing 根据 context token 数获取区间定价。
|
||||
// 如果有区间列表,找到匹配区间并构造 ModelPricing;否则直接返回 BasePricing。
|
||||
func (r *ModelPricingResolver) GetIntervalPricing(resolved *ResolvedPricing, totalContextTokens int) *ModelPricing {
|
||||
if len(resolved.Intervals) == 0 {
|
||||
return resolved.BasePricing
|
||||
}
|
||||
|
||||
iv := FindMatchingInterval(resolved.Intervals, totalContextTokens)
|
||||
if iv == nil {
|
||||
return resolved.BasePricing
|
||||
}
|
||||
|
||||
return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown)
|
||||
}
|
||||
|
||||
// intervalToModelPricing 将区间定价转换为 ModelPricing
|
||||
func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool) *ModelPricing {
|
||||
pricing := &ModelPricing{
|
||||
SupportsCacheBreakdown: supportsCacheBreakdown,
|
||||
}
|
||||
if iv.InputPrice != nil {
|
||||
pricing.InputPricePerToken = *iv.InputPrice
|
||||
pricing.InputPricePerTokenPriority = *iv.InputPrice
|
||||
}
|
||||
if iv.OutputPrice != nil {
|
||||
pricing.OutputPricePerToken = *iv.OutputPrice
|
||||
pricing.OutputPricePerTokenPriority = *iv.OutputPrice
|
||||
}
|
||||
if iv.CacheWritePrice != nil {
|
||||
pricing.CacheCreationPricePerToken = *iv.CacheWritePrice
|
||||
pricing.CacheCreation5mPrice = *iv.CacheWritePrice
|
||||
pricing.CacheCreation1hPrice = *iv.CacheWritePrice
|
||||
}
|
||||
if iv.CacheReadPrice != nil {
|
||||
pricing.CacheReadPricePerToken = *iv.CacheReadPrice
|
||||
pricing.CacheReadPricePerTokenPriority = *iv.CacheReadPrice
|
||||
}
|
||||
return pricing
|
||||
}
|
||||
|
||||
// GetRequestTierPrice 根据层级标签获取按次价格
|
||||
func (r *ModelPricingResolver) GetRequestTierPrice(resolved *ResolvedPricing, tierLabel string) float64 {
|
||||
for _, tier := range resolved.RequestTiers {
|
||||
if tier.TierLabel == tierLabel && tier.PerRequestPrice != nil {
|
||||
return *tier.PerRequestPrice
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetRequestTierPriceByContext 根据 context token 数获取按次价格
|
||||
func (r *ModelPricingResolver) GetRequestTierPriceByContext(resolved *ResolvedPricing, totalContextTokens int) float64 {
|
||||
iv := FindMatchingInterval(resolved.RequestTiers, totalContextTokens)
|
||||
if iv != nil && iv.PerRequestPrice != nil {
|
||||
return *iv.PerRequestPrice
|
||||
}
|
||||
return 0
|
||||
}
|
||||
164
backend/internal/service/model_pricing_resolver_test.go
Normal file
164
backend/internal/service/model_pricing_resolver_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// resolverPtrFloat64 returns a pointer to a fresh copy of v.
func resolverPtrFloat64(v float64) *float64 {
	p := new(float64)
	*p = v
	return p
}
|
||||
// resolverPtrInt returns a pointer to a fresh copy of v.
func resolverPtrInt(v int) *int {
	p := new(int)
	*p = v
	return p
}
|
||||
|
||||
func newTestBillingServiceForResolver() *BillingService {
|
||||
bs := &BillingService{
|
||||
fallbackPrices: make(map[string]*ModelPricing),
|
||||
}
|
||||
bs.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
|
||||
InputPricePerToken: 3e-6,
|
||||
OutputPricePerToken: 15e-6,
|
||||
CacheCreationPricePerToken: 3.75e-6,
|
||||
CacheReadPricePerToken: 0.3e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
return bs
|
||||
}
|
||||
|
||||
func TestResolve_NoGroupID(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: nil,
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModeToken, resolved.Mode)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
// BillingService.GetModelPricing uses fallback internally, but resolveBasePricing
|
||||
// reports "litellm" when GetModelPricing succeeds (regardless of internal source)
|
||||
require.Equal(t, "litellm", resolved.Source)
|
||||
}
|
||||
|
||||
func TestResolve_UnknownModel(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "unknown-model-xyz",
|
||||
GroupID: nil,
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Nil(t, resolved.BasePricing)
|
||||
// Unknown model: GetModelPricing returns error, source is "fallback"
|
||||
require.Equal(t, "fallback", resolved.Source)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_NoIntervals(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
basePricing := &ModelPricing{InputPricePerToken: 5e-6}
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Intervals: nil,
|
||||
}
|
||||
|
||||
result := r.GetIntervalPricing(resolved, 50000)
|
||||
require.Equal(t, basePricing, result)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_MatchesInterval(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: &ModelPricing{InputPricePerToken: 5e-6},
|
||||
SupportsCacheBreakdown: true,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), InputPrice: resolverPtrFloat64(1e-6), OutputPrice: resolverPtrFloat64(2e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: resolverPtrFloat64(3e-6), OutputPrice: resolverPtrFloat64(6e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
result := r.GetIntervalPricing(resolved, 50000)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 1e-6, result.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 2e-6, result.OutputPricePerToken, 1e-12)
|
||||
require.True(t, result.SupportsCacheBreakdown)
|
||||
|
||||
result2 := r.GetIntervalPricing(resolved, 200000)
|
||||
require.NotNil(t, result2)
|
||||
require.InDelta(t, 3e-6, result2.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
basePricing := &ModelPricing{InputPricePerToken: 99e-6}
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: resolverPtrInt(50000), InputPrice: resolverPtrFloat64(1e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
result := r.GetIntervalPricing(resolved, 5000)
|
||||
require.Equal(t, basePricing, result)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPrice(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: resolverPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: resolverPtrFloat64(0.08)},
|
||||
},
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPriceByContext(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), PerRequestPrice: resolverPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: resolverPtrFloat64(0.10)},
|
||||
},
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.05, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
|
||||
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: nil},
|
||||
},
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
}
|
||||
@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideScheduledTestService,
|
||||
ProvideScheduledTestRunnerService,
|
||||
NewGroupCapacityService,
|
||||
NewChannelService,
|
||||
NewModelPricingResolver,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user