A privacy-first, self-hosted, fully open-source personal knowledge management application, written in TypeScript and Go. (PERSONAL FORK)
at lambda-fork/main 183 lines 4.7 kB view raw
1// SiYuan - Refactor your thinking 2// Copyright (c) 2020-present, b3log.org 3// 4// This program is free software: you can redistribute it and/or modify 5// it under the terms of the GNU Affero General Public License as published by 6// the Free Software Foundation, either version 3 of the License, or 7// (at your option) any later version. 8// 9// This program is distributed in the hope that it will be useful, 10// but WITHOUT ANY WARRANTY; without even the implied warranty of 11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12// GNU Affero General Public License for more details. 13// 14// You should have received a copy of the GNU Affero General Public License 15// along with this program. If not, see <https://www.gnu.org/licenses/>. 16 17package model 18 19import ( 20 "bytes" 21 "strings" 22 23 "github.com/88250/lute/ast" 24 "github.com/88250/lute/parse" 25 "github.com/sashabaranov/go-openai" 26 "github.com/siyuan-note/siyuan/kernel/treenode" 27 "github.com/siyuan-note/siyuan/kernel/util" 28) 29 30func ChatGPT(msg string) (ret string) { 31 if !isOpenAIAPIEnabled() { 32 return 33 } 34 35 return chatGPT(msg, false) 36} 37 38func ChatGPTWithAction(ids []string, action string) (ret string) { 39 if !isOpenAIAPIEnabled() { 40 return 41 } 42 43 if "Clear context" == action { 44 // AI clear context action https://github.com/siyuan-note/siyuan/issues/10255 45 cachedContextMsg = nil 46 return 47 } 48 49 msg := getBlocksContent(ids) 50 ret = chatGPTWithAction(msg, action, false) 51 return 52} 53 54var cachedContextMsg []string 55 56func chatGPT(msg string, cloud bool) (ret string) { 57 if "Clear context" == strings.TrimSpace(msg) { 58 // AI clear context action https://github.com/siyuan-note/siyuan/issues/10255 59 cachedContextMsg = nil 60 return 61 } 62 63 ret, retCtxMsgs, err := chatGPTContinueWrite(msg, cachedContextMsg, cloud) 64 if err != nil { 65 return 66 } 67 cachedContextMsg = append(cachedContextMsg, retCtxMsgs...) 
68 return 69} 70 71func chatGPTWithAction(msg string, action string, cloud bool) (ret string) { 72 action = strings.TrimSpace(action) 73 if "" != action { 74 msg = action + ":\n\n" + msg 75 } 76 ret, _, err := chatGPTContinueWrite(msg, nil, cloud) 77 if err != nil { 78 return 79 } 80 return 81} 82 83func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret string, retContextMsgs []string, err error) { 84 util.PushEndlessProgress("Requesting...") 85 defer util.ClearPushProgress(100) 86 87 if Conf.AI.OpenAI.APIMaxContexts < len(contextMsgs) { 88 contextMsgs = contextMsgs[len(contextMsgs)-Conf.AI.OpenAI.APIMaxContexts:] 89 } 90 91 var gpt GPT 92 if cloud { 93 gpt = &CloudGPT{} 94 } else { 95 gpt = &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APIUserAgent, Conf.AI.OpenAI.APIVersion, Conf.AI.OpenAI.APIProvider)} 96 } 97 98 buf := &bytes.Buffer{} 99 for i := 0; i < Conf.AI.OpenAI.APIMaxContexts; i++ { 100 part, stop, chatErr := gpt.chat(msg, contextMsgs) 101 buf.WriteString(part) 102 103 if stop || nil != chatErr { 104 break 105 } 106 107 util.PushEndlessProgress("Continue requesting...") 108 } 109 110 ret = buf.String() 111 ret = strings.TrimSpace(ret) 112 if "" != ret { 113 retContextMsgs = append(retContextMsgs, msg, ret) 114 } 115 return 116} 117 118func isOpenAIAPIEnabled() bool { 119 if "" == Conf.AI.OpenAI.APIKey { 120 util.PushMsg(Conf.Language(193), 5000) 121 return false 122 } 123 return true 124} 125 126func getBlocksContent(ids []string) string { 127 var nodes []*ast.Node 128 trees := map[string]*parse.Tree{} 129 for _, id := range ids { 130 bt := treenode.GetBlockTree(id) 131 if nil == bt { 132 continue 133 } 134 135 var tree *parse.Tree 136 if tree = trees[bt.RootID]; nil == tree { 137 tree, _ = LoadTreeByBlockID(bt.RootID) 138 if nil == tree { 139 continue 140 } 141 142 trees[bt.RootID] = tree 143 } 144 145 if node := treenode.GetNodeInTree(tree, id); nil != 
node { 146 if ast.NodeDocument == node.Type { 147 for child := node.FirstChild; nil != child; child = child.Next { 148 nodes = append(nodes, child) 149 } 150 } else { 151 nodes = append(nodes, node) 152 } 153 } 154 } 155 156 luteEngine := util.NewLute() 157 buf := bytes.Buffer{} 158 for _, node := range nodes { 159 md := treenode.ExportNodeStdMd(node, luteEngine) 160 buf.WriteString(md) 161 buf.WriteString("\n\n") 162 } 163 return buf.String() 164} 165 166type GPT interface { 167 chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) 168} 169 170type OpenAIGPT struct { 171 c *openai.Client 172} 173 174func (gpt *OpenAIGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) { 175 return util.ChatGPT(msg, contextMsgs, gpt.c, Conf.AI.OpenAI.APIModel, Conf.AI.OpenAI.APIMaxTokens, Conf.AI.OpenAI.APITemperature, Conf.AI.OpenAI.APITimeout) 176} 177 178type CloudGPT struct { 179} 180 181func (gpt *CloudGPT) chat(msg string, contextMsgs []string) (partRet string, stop bool, err error) { 182 return CloudChatGPT(msg, contextMsgs) 183}