feat(channels): add "Available Channels" aggregate view
Add a read-only aggregate view per channel: its linked groups and a deterministic wildcard-free supported-model list with pricing details. Backend - service.Channel.SupportedModels(): combine ModelMapping keys with same-platform ModelPricing.Models; trailing "*" keys expand via pricing prefix match; platforms without a mapping produce no entries (intentional "no mapping = not shown" rule). - Extract splitWildcardSuffix() shared with toModelEntry. - Build a per-call pricing lookup map (platform+lowerName -> *pricing) to avoid O(N*M) scans in SupportedModels. - ChannelService.ListAvailable() aggregates channels + active groups; filters out group IDs no longer active. - Admin route GET /api/v1/admin/channels/available returns the full DTO (id, status, billing_model_source, restrict_models, groups, supported_models). - User route GET /api/v1/channels/available applies three filters: Status==active, visible-group intersection, and platform filter on supported_models (prevents cross-platform leak when a channel links to both a user-accessible group and an inaccessible one on another platform). Response is a plain array (matches the /groups/available sibling shape). Field whitelist omits billing_model_source, restrict_models, ids, status, sort_order. Frontend - New /admin/available-channels and /available-channels views backed by a shared AvailableChannelsTable component (admin adds status + billing-source columns via slots). - PricingRow extracted to its own SFC; SupportedModelChip references shared billing-mode constants in constants/channel.ts. - Sidebar: new entry above "渠道管理" for admin; matching entry in user nav. - i18n: zh + en coverage for both namespaces. Tests - SupportedModels: wildcard-only pricing skipped, prefix-matches- nothing, cross-platform bleed, case-insensitive dedup, empty platform mapping. - ListAvailable: nil groupRepo, inactive-group-ID dropped, stable case-insensitive name sort. - User handler: 401 on unauthenticated, visible-group intersection, platform filter on supported_models, JSON whitelist. - Admin handler: full DTO including default BillingModelSource fallback. Refs: issue #1729
This commit is contained in:
@@ -345,3 +345,175 @@ type ChannelUsageFields struct {
|
||||
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
|
||||
ModelMappingChain string // 映射链描述,如 "a→b→c"
|
||||
}
|
||||
|
||||
// SupportedModel 渠道的一个支持模型条目(无通配符、可直接展示给用户)
|
||||
type SupportedModel struct {
|
||||
Name string // 用户侧模型名
|
||||
Platform string // 所属平台
|
||||
Pricing *ChannelModelPricing // 定价详情(nil 表示未配置定价)
|
||||
}
|
||||
|
||||
// wildcardSuffix 是模型模式中的通配符后缀标记(仅支持尾部匹配)。
|
||||
const wildcardSuffix = "*"
|
||||
|
||||
// splitWildcardSuffix 将模型模式拆分为 (prefix, isWildcard)。
|
||||
//
|
||||
// "claude-opus-*" → ("claude-opus-", true)
|
||||
// "claude-opus-4" → ("claude-opus-4", false)
|
||||
// "*" → ("", true)
|
||||
//
|
||||
// 注意:返回的 prefix 保持原始大小写,由调用方按需 ToLower。
|
||||
func splitWildcardSuffix(pattern string) (prefix string, isWildcard bool) {
|
||||
if strings.HasSuffix(pattern, wildcardSuffix) {
|
||||
return strings.TrimSuffix(pattern, wildcardSuffix), true
|
||||
}
|
||||
return pattern, false
|
||||
}
|
||||
|
||||
// GetModelPricingByPlatform 在指定平台下查找精确模型的定价,未找到返回 nil。
|
||||
// 与 GetModelPricing 的区别:按 Platform 隔离,避免跨平台同名模型误匹配。
|
||||
func (c *Channel) GetModelPricingByPlatform(platform, model string) *ChannelModelPricing {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
modelLower := strings.ToLower(model)
|
||||
for i := range c.ModelPricing {
|
||||
if c.ModelPricing[i].Platform != platform {
|
||||
continue
|
||||
}
|
||||
for _, m := range c.ModelPricing[i].Models {
|
||||
if strings.ToLower(m) == modelLower {
|
||||
cp := c.ModelPricing[i].Clone()
|
||||
return &cp
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pricingLookup 是渠道定价在单个计算过程中的索引:platform → (lowerName → *pricing)。
|
||||
// 用于将 SupportedModels 的定价解析从 O(N*M) 降到 O(N+M)。
|
||||
type pricingLookup map[string]map[string]*ChannelModelPricing
|
||||
|
||||
// buildPricingLookup 对渠道的定价列表做一次扫描,生成 platform+模型名 的索引。
|
||||
// 索引值是定价条目的 Clone 指针,调用方可安全按需返回副本而不污染缓存。
|
||||
// wildcard 后缀(如 "claude-*")不会被索引(它们不是精确模型名)。
|
||||
func buildPricingLookup(pricings []ChannelModelPricing) pricingLookup {
|
||||
lookup := make(pricingLookup, len(pricings))
|
||||
for i := range pricings {
|
||||
p := pricings[i]
|
||||
byModel, ok := lookup[p.Platform]
|
||||
if !ok {
|
||||
byModel = make(map[string]*ChannelModelPricing, len(p.Models))
|
||||
lookup[p.Platform] = byModel
|
||||
}
|
||||
for _, m := range p.Models {
|
||||
if _, wild := splitWildcardSuffix(m); wild {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(m)
|
||||
if _, exists := byModel[lower]; exists {
|
||||
continue // 首个命中胜出(保持 case-insensitive 去重后第一个定价)
|
||||
}
|
||||
cp := pricings[i].Clone()
|
||||
byModel[lower] = &cp
|
||||
}
|
||||
}
|
||||
return lookup
|
||||
}
|
||||
|
||||
// pricedNamesFor 返回指定平台下已索引的精确模型名(保留原始大小写,按添加顺序)。
|
||||
// 它是从 pricingLookup 中取 keys 并回查原始 ModelPricing 以得到原样字符串。
|
||||
func pricedNamesFor(pricings []ChannelModelPricing, platform string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]string, 0)
|
||||
for i := range pricings {
|
||||
if pricings[i].Platform != platform {
|
||||
continue
|
||||
}
|
||||
for _, m := range pricings[i].Models {
|
||||
if _, wild := splitWildcardSuffix(m); wild {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(m)
|
||||
if _, ok := seen[lower]; ok {
|
||||
continue
|
||||
}
|
||||
seen[lower] = struct{}{}
|
||||
out = append(out, m)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SupportedModels 计算渠道的支持模型列表,结果保证不含通配符。
|
||||
//
|
||||
// 算法(以渠道自身的 ModelMapping 为唯一入口):
|
||||
// - 遍历 Channel.ModelMapping 的每个 platform 条目;
|
||||
// - 映射 key 不带尾部 "*":直接作为一个支持模型名(即使没有匹配的定价行,也会产出 Pricing=nil 的条目);
|
||||
// - 映射 key 带尾部 "*":用同 platform 的 ModelPricing.Models 做前缀匹配展开(定价中带 "*" 的条目被忽略,因为它们本身就是模式,不是具体模型名);
|
||||
// - 未在 ModelMapping 中出现的 platform 不会产出任何条目——这是**刻意设计**("没配映射就不显示"),即使该平台有定价行。
|
||||
//
|
||||
// 每个结果尝试从 pricingLookup(平台+模型名索引)查找精确定价,未配置则 Pricing=nil。
|
||||
// 结果按 (Platform, Name) 稳定排序,并按 (Platform, lowercase(Name)) 去重。
|
||||
func (c *Channel) SupportedModels() []SupportedModel {
|
||||
if c == nil || len(c.ModelMapping) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
lookup := buildPricingLookup(c.ModelPricing)
|
||||
|
||||
type dedupKey struct {
|
||||
platform string
|
||||
name string
|
||||
}
|
||||
seen := make(map[dedupKey]struct{})
|
||||
result := make([]SupportedModel, 0)
|
||||
|
||||
add := func(platform, name string) {
|
||||
key := dedupKey{platform: platform, name: strings.ToLower(name)}
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
var pricing *ChannelModelPricing
|
||||
if byModel, ok := lookup[platform]; ok {
|
||||
if p, ok := byModel[strings.ToLower(name)]; ok {
|
||||
pricing = p
|
||||
}
|
||||
}
|
||||
result = append(result, SupportedModel{
|
||||
Name: name,
|
||||
Platform: platform,
|
||||
Pricing: pricing,
|
||||
})
|
||||
}
|
||||
|
||||
for platform, mapping := range c.ModelMapping {
|
||||
if len(mapping) == 0 {
|
||||
continue
|
||||
}
|
||||
pricedNames := pricedNamesFor(c.ModelPricing, platform)
|
||||
for src := range mapping {
|
||||
prefix, isWild := splitWildcardSuffix(src)
|
||||
if isWild {
|
||||
prefixLower := strings.ToLower(prefix)
|
||||
for _, candidate := range pricedNames {
|
||||
if strings.HasPrefix(strings.ToLower(candidate), prefixLower) {
|
||||
add(platform, candidate)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
add(platform, src)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
if result[i].Platform != result[j].Platform {
|
||||
return result[i].Platform < result[j].Platform
|
||||
}
|
||||
return result[i].Name < result[j].Name
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
84
backend/internal/service/channel_available.go
Normal file
84
backend/internal/service/channel_available.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AvailableGroupRef 渠道视图中关联分组的简要信息。
|
||||
type AvailableGroupRef struct {
|
||||
ID int64
|
||||
Name string
|
||||
Platform string
|
||||
}
|
||||
|
||||
// AvailableChannel 可用渠道视图:用于「可用渠道」页面展示渠道基础信息 +
|
||||
// 关联的分组 + 推导出的支持模型列表(无通配符)。
|
||||
type AvailableChannel struct {
|
||||
ID int64
|
||||
Name string
|
||||
Description string
|
||||
Status string
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
Groups []AvailableGroupRef
|
||||
SupportedModels []SupportedModel
|
||||
}
|
||||
|
||||
// ListAvailable 返回所有渠道的可用视图:每个渠道附带关联分组信息与支持模型列表。
|
||||
//
|
||||
// 支持模型通过 (*Channel).SupportedModels() 计算得到(见 channel.go)。
|
||||
// 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中
|
||||
// 的分组(已停用或删除)会被忽略。
|
||||
func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) {
|
||||
channels, err := s.repo.ListAll(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list channels: %w", err)
|
||||
}
|
||||
|
||||
groupByID := make(map[int64]AvailableGroupRef)
|
||||
if s.groupRepo != nil {
|
||||
groups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active groups: %w", err)
|
||||
}
|
||||
for i := range groups {
|
||||
g := groups[i]
|
||||
groupByID[g.ID] = AvailableGroupRef{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Platform: g.Platform,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]AvailableChannel, 0, len(channels))
|
||||
for i := range channels {
|
||||
ch := &channels[i]
|
||||
groups := make([]AvailableGroupRef, 0, len(ch.GroupIDs))
|
||||
for _, gid := range ch.GroupIDs {
|
||||
if ref, ok := groupByID[gid]; ok {
|
||||
groups = append(groups, ref)
|
||||
}
|
||||
}
|
||||
sort.Slice(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
|
||||
|
||||
out = append(out, AvailableChannel{
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Description: ch.Description,
|
||||
Status: ch.Status,
|
||||
BillingModelSource: ch.BillingModelSource,
|
||||
RestrictModels: ch.RestrictModels,
|
||||
Groups: groups,
|
||||
SupportedModels: ch.SupportedModels(),
|
||||
})
|
||||
}
|
||||
|
||||
sort.SliceStable(out, func(i, j int) bool {
|
||||
return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name)
|
||||
})
|
||||
return out, nil
|
||||
}
|
||||
119
backend/internal/service/channel_available_test.go
Normal file
119
backend/internal/service/channel_available_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub,
|
||||
// 仅实现 ListActive;其他方法对本测试无关,返回零值即可。
|
||||
type stubGroupRepoForAvailable struct {
|
||||
activeGroups []Group
|
||||
}
|
||||
|
||||
func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) {
|
||||
return s.activeGroups, nil
|
||||
}
|
||||
|
||||
func (s *stubGroupRepoForAvailable) Create(ctx context.Context, group *Group) error { return nil }
|
||||
func (s *stubGroupRepoForAvailable) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) Update(ctx context.Context, group *Group) error { return nil }
|
||||
func (s *stubGroupRepoForAvailable) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (s *stubGroupRepoForAvailable) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
|
||||
return 0, 0, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubGroupRepoForAvailable) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// newAvailableChannelService 构造一个 ChannelService,channelRepo.ListAll 返回给定 channels,
|
||||
// groupRepo 由参数决定(可传 nil 测试 nil 分支)。
|
||||
func newAvailableChannelService(channels []Channel, groupRepo GroupRepository) *ChannelService {
|
||||
repo := &mockChannelRepository{
|
||||
listAllFn: func(ctx context.Context) ([]Channel, error) { return channels, nil },
|
||||
}
|
||||
return NewChannelService(repo, groupRepo, nil)
|
||||
}
|
||||
|
||||
func TestListAvailable_NilGroupRepo_NoGroupsAttached(t *testing.T) {
|
||||
// groupRepo 为 nil 时不应 panic,且每个渠道的 Groups 应为空切片。
|
||||
channels := []Channel{{
|
||||
ID: 1,
|
||||
Name: "chA",
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10, 20},
|
||||
}}
|
||||
svc := newAvailableChannelService(channels, nil)
|
||||
out, err := svc.ListAvailable(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
require.Empty(t, out[0].Groups)
|
||||
}
|
||||
|
||||
func TestListAvailable_InactiveGroupIDSilentlyDropped(t *testing.T) {
|
||||
// 渠道 GroupIDs 中引用的 group 未出现在 ListActive 结果中(已停用或删除),应被静默丢弃。
|
||||
channels := []Channel{{
|
||||
ID: 1,
|
||||
Name: "chA",
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{1, 99},
|
||||
}}
|
||||
groupRepo := &stubGroupRepoForAvailable{
|
||||
activeGroups: []Group{{ID: 1, Name: "g1", Platform: "anthropic"}},
|
||||
}
|
||||
svc := newAvailableChannelService(channels, groupRepo)
|
||||
out, err := svc.ListAvailable(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 1)
|
||||
require.Len(t, out[0].Groups, 1)
|
||||
require.Equal(t, int64(1), out[0].Groups[0].ID)
|
||||
}
|
||||
|
||||
func TestListAvailable_SortedByName(t *testing.T) {
|
||||
channels := []Channel{
|
||||
{ID: 1, Name: "beta"},
|
||||
{ID: 2, Name: "Alpha"},
|
||||
{ID: 3, Name: "charlie"},
|
||||
}
|
||||
svc := newAvailableChannelService(channels, nil)
|
||||
out, err := svc.ListAvailable(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 3)
|
||||
require.Equal(t, "Alpha", out[0].Name)
|
||||
require.Equal(t, "beta", out[1].Name)
|
||||
require.Equal(t, "charlie", out[2].Name)
|
||||
}
|
||||
@@ -141,6 +141,7 @@ const (
|
||||
// ChannelService 渠道管理服务
|
||||
type ChannelService struct {
|
||||
repo ChannelRepository
|
||||
groupRepo GroupRepository
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
|
||||
cache atomic.Value // *channelCache
|
||||
@@ -148,9 +149,10 @@ type ChannelService struct {
|
||||
}
|
||||
|
||||
// NewChannelService 创建渠道服务实例
|
||||
func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
|
||||
func NewChannelService(repo ChannelRepository, groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
|
||||
s := &ChannelService{
|
||||
repo: repo,
|
||||
groupRepo: groupRepo,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
return s
|
||||
@@ -884,12 +886,7 @@ func conflictsBetween(a, b modelEntry) bool {
|
||||
|
||||
// toModelEntry 将模型名转换为 modelEntry
|
||||
func toModelEntry(pattern string) modelEntry {
|
||||
lower := strings.ToLower(pattern)
|
||||
isWild := strings.HasSuffix(lower, "*")
|
||||
prefix := lower
|
||||
if isWild {
|
||||
prefix = strings.TrimSuffix(lower, "*")
|
||||
}
|
||||
prefix, isWild := splitWildcardSuffix(strings.ToLower(pattern))
|
||||
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
|
||||
}
|
||||
|
||||
|
||||
@@ -189,11 +189,11 @@ func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func newTestChannelService(repo *mockChannelRepository) *ChannelService {
|
||||
return NewChannelService(repo, nil)
|
||||
return NewChannelService(repo, nil, nil)
|
||||
}
|
||||
|
||||
func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService {
|
||||
return NewChannelService(repo, auth)
|
||||
return NewChannelService(repo, nil, auth)
|
||||
}
|
||||
|
||||
// makeStandardRepo returns a repo that serves one active channel with anthropic pricing
|
||||
|
||||
@@ -433,3 +433,207 @@ func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "unbounded")
|
||||
require.Contains(t, err.Error(), "last")
|
||||
}
|
||||
|
||||
func TestSupportedModels_ExactKeysAndPricing(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 10, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)},
|
||||
{ID: 11, Platform: "anthropic", Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(1.5e-5)},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
"claude-opus-4-6": "claude-opus-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := ch.SupportedModels()
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, "anthropic", got[0].Platform)
|
||||
require.Equal(t, "claude-opus-4-6", got[0].Name)
|
||||
require.NotNil(t, got[0].Pricing)
|
||||
require.Equal(t, int64(11), got[0].Pricing.ID)
|
||||
require.Equal(t, "claude-sonnet-4-6", got[1].Name)
|
||||
require.Equal(t, int64(10), got[1].Pricing.ID)
|
||||
}
|
||||
|
||||
func TestSupportedModels_WildcardExpandedFromPricing(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}},
|
||||
{ID: 2, Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {
|
||||
"claude-sonnet-*": "claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := ch.SupportedModels()
|
||||
names := make([]string, 0, len(got))
|
||||
for _, m := range got {
|
||||
names = append(names, m.Name)
|
||||
}
|
||||
require.ElementsMatch(t, []string{"claude-sonnet-4-5", "claude-sonnet-4-6"}, names)
|
||||
for _, m := range got {
|
||||
require.NotContains(t, m.Name, "*")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedModels_PlatformWithoutMappingSkipped(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
|
||||
{ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"},
|
||||
// openai 没有 mapping 条目
|
||||
},
|
||||
}
|
||||
|
||||
got := ch.SupportedModels()
|
||||
require.Len(t, got, 1)
|
||||
require.Equal(t, "anthropic", got[0].Platform)
|
||||
require.Equal(t, "claude-sonnet-4-6", got[0].Name)
|
||||
}
|
||||
|
||||
func TestSupportedModels_MissingPricingKeepsNilPricing(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"},
|
||||
},
|
||||
}
|
||||
|
||||
got := ch.SupportedModels()
|
||||
require.Len(t, got, 1)
|
||||
require.Equal(t, "claude-sonnet-4-6", got[0].Name)
|
||||
require.Nil(t, got[0].Pricing)
|
||||
}
|
||||
|
||||
func TestSupportedModels_DedupAndSort(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}},
|
||||
{ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {
|
||||
"claude-sonnet-4-6": "upstream-a",
|
||||
"claude-sonnet-*": "upstream-a",
|
||||
},
|
||||
"openai": {"gpt-4o": "gpt-4o"},
|
||||
},
|
||||
}
|
||||
|
||||
got := ch.SupportedModels()
|
||||
require.Len(t, got, 3)
|
||||
require.Equal(t, "anthropic", got[0].Platform)
|
||||
require.Equal(t, "claude-sonnet-4-5", got[0].Name)
|
||||
require.Equal(t, "anthropic", got[1].Platform)
|
||||
require.Equal(t, "claude-sonnet-4-6", got[1].Name)
|
||||
require.Equal(t, "openai", got[2].Platform)
|
||||
require.Equal(t, "gpt-4o", got[2].Name)
|
||||
}
|
||||
|
||||
func TestSupportedModels_NilChannelAndEmpty(t *testing.T) {
|
||||
var nilCh *Channel
|
||||
require.Nil(t, nilCh.SupportedModels())
|
||||
|
||||
empty := &Channel{}
|
||||
require.Nil(t, empty.SupportedModels())
|
||||
}
|
||||
|
||||
func TestGetModelPricingByPlatform(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)},
|
||||
{ID: 2, Platform: "openai", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(1e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
ant := ch.GetModelPricingByPlatform("anthropic", "claude-sonnet-4-6")
|
||||
require.NotNil(t, ant)
|
||||
require.Equal(t, int64(1), ant.ID)
|
||||
|
||||
oa := ch.GetModelPricingByPlatform("openai", "claude-sonnet-4-6")
|
||||
require.NotNil(t, oa)
|
||||
require.Equal(t, int64(2), oa.ID)
|
||||
|
||||
require.Nil(t, ch.GetModelPricingByPlatform("gemini", "claude-sonnet-4-6"))
|
||||
}
|
||||
|
||||
func TestSupportedModels_WildcardOnlyPricingRowsSkipped(t *testing.T) {
|
||||
// 定价中含通配符条目(pattern),不应被当作具体模型名展开。
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-*", "claude-sonnet-4-6"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {"claude-sonnet-*": "claude-sonnet-4-6"},
|
||||
},
|
||||
}
|
||||
got := ch.SupportedModels()
|
||||
require.Len(t, got, 1)
|
||||
require.Equal(t, "claude-sonnet-4-6", got[0].Name)
|
||||
for _, m := range got {
|
||||
require.NotContains(t, m.Name, "*")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedModels_WildcardPrefixMatchesNothing(t *testing.T) {
|
||||
// 通配符模式无任何对应定价模型时,该平台应产出 0 个模型。
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Platform: "openai", Models: []string{"gpt-4o"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {"gpt-foo-*": "gpt-foo-1"},
|
||||
},
|
||||
}
|
||||
require.Empty(t, ch.SupportedModels())
|
||||
}
|
||||
|
||||
func TestSupportedModels_CrossPlatformPricingDoesNotBleed(t *testing.T) {
|
||||
// anthropic 的通配符不应拉入 openai 定价行,哪怕名字恰好前缀匹配。
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Platform: "openai", Models: []string{"claude-sonnet-4-6"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {"claude-sonnet-*": "x"},
|
||||
},
|
||||
}
|
||||
require.Empty(t, ch.SupportedModels())
|
||||
}
|
||||
|
||||
func TestSupportedModels_CaseInsensitiveDedup(t *testing.T) {
|
||||
// 两行定价用不同大小写定义了同一模型,结果应去重为 1 条;首次出现的原始大小写保留。
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Platform: "openai", Models: []string{"GPT-4o"}},
|
||||
{ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"openai": {"gpt-*": "x"},
|
||||
},
|
||||
}
|
||||
got := ch.SupportedModels()
|
||||
require.Len(t, got, 1)
|
||||
require.Equal(t, "GPT-4o", got[0].Name)
|
||||
}
|
||||
|
||||
func TestSupportedModels_EmptyPlatformMapping(t *testing.T) {
|
||||
// ModelMapping 有一个 platform key 但 value 是空 map —— 该 platform 应被跳过。
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {},
|
||||
},
|
||||
}
|
||||
require.Empty(t, ch.SupportedModels())
|
||||
}
|
||||
|
||||
@@ -184,7 +184,7 @@ func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelP
|
||||
return map[int64]string{groupID: "anthropic"}, nil
|
||||
},
|
||||
}
|
||||
cs := NewChannelService(repo, nil)
|
||||
cs := NewChannelService(repo, nil, nil)
|
||||
bs := newTestBillingServiceForResolver()
|
||||
return NewModelPricingResolver(cs, bs)
|
||||
}
|
||||
@@ -517,7 +517,7 @@ func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
|
||||
return nil, errors.New("database unavailable")
|
||||
},
|
||||
}
|
||||
cs := NewChannelService(repo, nil)
|
||||
cs := NewChannelService(repo, nil, nil)
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(cs, bs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user