HEX
Server: Apache/2.4.54 (Win64) OpenSSL/1.1.1p PHP/7.4.30
System: Windows NT website-api 10.0 build 20348 (Windows Server 2016) AMD64
User: SYSTEM (0)
PHP: 7.4.30
Disabled: NONE
Upload Files
File: C:/github_repos/casibase_customer_0022/model/huggingface.go
// Copyright 2023 The Casibase Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package model

import (
	"context"
	"fmt"
	"io"
	"strings"

	"github.com/casibase/casibase/i18n"
	"github.com/casibase/casibase/proxy"
	"github.com/hupe1980/go-huggingface"
)

// HuggingFaceModelProvider holds the settings used to query a model through
// the Hugging Face inference client.
type HuggingFaceModelProvider struct {
	// subType is sent as the Model of each text-generation request and is
	// also handed to the token-count helpers.
	subType string
	// secretKey is passed to huggingface.NewInferenceClient.
	secretKey string
	// temperature is sent as the Temperature parameter of each request.
	temperature float32
}

// NewHuggingFaceModelProvider builds a provider from the given model
// sub-type, secret key and temperature. The returned error is always nil.
func NewHuggingFaceModelProvider(subType string, secretKey string, temperature float32) (*HuggingFaceModelProvider, error) {
	provider := new(HuggingFaceModelProvider)
	provider.subType = subType
	provider.secretKey = secretKey
	provider.temperature = temperature
	return provider, nil
}

// GetPricing returns a fixed, human-readable pricing note: the Hugging Face
// pricing URL followed by a statement that usage is not charged.
func (p *HuggingFaceModelProvider) GetPricing() string {
	return "URL:\n" +
		"https://huggingface.co/pricing\n" +
		"\n" +
		"Not charged\n"
}

// calculatePrice records the billing currency on modelResult. It only sets
// Currency to "USD", leaves every other field untouched, and always returns
// nil.
func (p *HuggingFaceModelProvider) calculatePrice(modelResult *ModelResult) error {
	const currency = "USD"
	modelResult.Currency = currency
	return nil
}

// QueryText sends question to the Hugging Face text-generation endpoint for
// the model named by p.subType and writes the answer to writer.
//
// When question starts with "$CasibaseDryRun$" no request is made: the token
// usage is estimated via getDefaultModelResult and the result is returned
// only if it is below getContextLength(p.subType); otherwise a translated
// error (in lang) is returned.
//
// For a real request, only the first line of the first generated candidate is
// written to writer and used to build the returned ModelResult. history,
// prompt, knowledgeMessages and agentInfo are not used by this provider.
//
// It returns an error if the request fails, if the service returns no
// candidates, if writing to writer fails, or if the result cannot be built.
func (p *HuggingFaceModelProvider) QueryText(question string, writer io.Writer, history []*RawMessage, prompt string, knowledgeMessages []*RawMessage, agentInfo *AgentInfo, lang string) (*ModelResult, error) {
	// Dry runs never talk to the service, so handle them before building a client.
	if strings.HasPrefix(question, "$CasibaseDryRun$") {
		modelResult, err := getDefaultModelResult(p.subType, question, "")
		if err != nil {
			// "%s" keeps any '%' in the translated text from being parsed as a verb.
			return nil, fmt.Errorf("%s", i18n.Translate(lang, "model:cannot calculate tokens"))
		}
		if getContextLength(p.subType) <= modelResult.TotalTokenCount {
			return nil, fmt.Errorf("%s", i18n.Translate(lang, "model:exceed max tokens"))
		}
		return modelResult, nil
	}

	ctx := context.Background()
	client := huggingface.NewInferenceClient(p.secretKey, func(o *huggingface.InferenceClientOptions) {
		o.HTTPClient = proxy.ProxyHttpClient
	})

	resp, err := client.TextGeneration(ctx, &huggingface.TextGenerationRequest{
		Inputs: question,
		Parameters: huggingface.TextGenerationParameters{
			Temperature: huggingface.PTR(float64(p.temperature)),
		},
		Options: huggingface.Options{
			WaitForModel: huggingface.PTR(true),
		},
		Model: p.subType,
	})
	if err != nil {
		return nil, err
	}
	// Guard the index below: an empty candidate list would otherwise panic.
	if len(resp) == 0 {
		return nil, fmt.Errorf("huggingface: empty text generation response for model %q", p.subType)
	}

	// Keep only the first line of the first candidate.
	respText := strings.Split(resp[0].GeneratedText, "\n")[0]

	if _, err = fmt.Fprint(writer, respText); err != nil {
		return nil, err
	}

	modelResult, err := getDefaultModelResult(p.subType, question, respText)
	if err != nil {
		return nil, err
	}

	if err = p.calculatePrice(modelResult); err != nil {
		return nil, err
	}

	return modelResult, nil
}