|
|
@@ -0,0 +1,195 @@
|
|
|
+package com.goafanti.common.utils;
|
|
|
+
|
|
|
+import com.alibaba.fastjson.JSON;
|
|
|
+import com.alibaba.fastjson.JSONObject;
|
|
|
+import com.goafanti.baiduAI.BaiduChatErrorEnums;
|
|
|
+import com.goafanti.baiduAI.bo.*;
|
|
|
+import com.goafanti.common.error.BusinessException;
|
|
|
+import okhttp3.*;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.beans.factory.annotation.Value;
|
|
|
+import org.springframework.scheduling.annotation.Async;
|
|
|
+
|
|
|
+import java.io.*;
|
|
|
+import java.net.HttpURLConnection;
|
|
|
+import java.net.MalformedURLException;
|
|
|
+import java.net.ProtocolException;
|
|
|
+import java.net.URL;
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
+import java.util.Calendar;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
+
|
|
|
+
|
|
|
+public class BaiduChatUtils {
|
|
|
+
|
|
|
+ @Value(value = "${baidu.ApiKey}")
|
|
|
+ private String baiduApiKey=null;
|
|
|
+
|
|
|
+ @Value(value = "${baidu.SecretKey}")
|
|
|
+ private String baiduSecretKey=null;
|
|
|
+ @Autowired
|
|
|
+ private RedisUtil redisUtil;
|
|
|
+
|
|
|
+ /*文心一言地址*/
|
|
|
+ private static final String BAIDU_CHAT_WXYY_URL="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=";
|
|
|
+ /*Ernie-Lite地址*/
|
|
|
+ private static final String BAIDU_CHAT_ERNIE_LITE_URL="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=";
|
|
|
+ /*accessToken获取地址*/
|
|
|
+ private static final String BAIDU_ACCESSTOKEN_URL="https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&";
|
|
|
+
|
|
|
+
|
|
|
+ static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().connectTimeout(120000, TimeUnit.MILLISECONDS)
|
|
|
+ .readTimeout(120000, TimeUnit.MILLISECONDS)
|
|
|
+ .build();
|
|
|
+
|
|
|
+
|
|
|
+ @Async
|
|
|
+ public void sendBaiduAiStream(InputSendChat inputSendChat) throws IOException {
|
|
|
+ SseResult res = null;
|
|
|
+ BufferedReader reader=null;
|
|
|
+ InputStreamReader inputStreamReader=null;
|
|
|
+ OutputStream outputStream=null;
|
|
|
+ try {
|
|
|
+ String baidu_url=BAIDU_CHAT_WXYY_URL+getRedisBaiduAccessToken();
|
|
|
+ URL url = new URL(baidu_url);
|
|
|
+ HttpURLConnection connection = (HttpURLConnection) url.openConnection();
|
|
|
+ connection.setRequestMethod("POST");
|
|
|
+ connection.setRequestProperty("Content-Type", "application/json");
|
|
|
+ connection.setDoInput(true);
|
|
|
+ connection.setDoOutput(true);
|
|
|
+ // 构造请求体
|
|
|
+// String requestBody = "{\"messages\":[{\"role\":\"user\",\"content\":\"给我介绍一条从四川自驾到拉萨的路线\"}],\"stream\":true}";
|
|
|
+ String requestBody=JSON.toJSONString(inputSendChat);
|
|
|
+ byte[] postData = requestBody.getBytes(StandardCharsets.UTF_8);
|
|
|
+ connection.setRequestProperty("Content-Length", String.valueOf(postData.length));
|
|
|
+ outputStream =connection.getOutputStream();
|
|
|
+ outputStream.write(postData);
|
|
|
+ InputStream responseStream = connection.getInputStream();
|
|
|
+ inputStreamReader = new InputStreamReader(responseStream, "UTF-8");
|
|
|
+ reader = new BufferedReader(inputStreamReader);
|
|
|
+ String line;
|
|
|
+ res = SseMap.sseEmitterMap.get(inputSendChat.getUserId());
|
|
|
+ while ((line = reader.readLine())!= null) {
|
|
|
+ // 每行数据中以 "data:" 开头的部分即为实际的响应数据
|
|
|
+ if (StringUtils.isNotBlank(line)){
|
|
|
+ if (line.startsWith("data:")) {
|
|
|
+ String data = line.substring("data:".length()).trim();
|
|
|
+ JSONObject jsonObject = JSONObject.parseObject(data);
|
|
|
+ Boolean isEnd = jsonObject.getBoolean("is_end");
|
|
|
+ res.sseEmitter.send(data);
|
|
|
+ if(isEnd){
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ //错误返回格式{"error_code":110,"error_msg":"Access token invalid or no longer valid"}
|
|
|
+ }else if(line.startsWith("{")) {
|
|
|
+ JSONObject jsonObject = JSONObject.parseObject(line);
|
|
|
+ Integer errorCode = jsonObject.getInteger("error_code");
|
|
|
+ OutChatER out=new OutChatER();
|
|
|
+ if (errorCode!=null){
|
|
|
+ out.setError_code(errorCode);
|
|
|
+ if (errorCode.equals("336003")){
|
|
|
+ out.setError_msg(BaiduChatErrorEnums.BycodeGetMsg(errorCode)+jsonObject.getString("error_msg"));
|
|
|
+ }else {
|
|
|
+ out.setError_msg(BaiduChatErrorEnums.BycodeGetMsg(errorCode));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ String errorStr=JSON.toJSONString(out);
|
|
|
+ LoggerUtils.debug(getClass(),errorStr);
|
|
|
+ res.sseEmitter.send(errorStr);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ res.sseEmitter.complete();
|
|
|
+ } catch (MalformedURLException e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ sendJitaoBaiWen(inputSendChat.getUserId(),"data:{\"error_code\":\"2\",\"error_msg\":\"域名解析异常\"}");
|
|
|
+ } catch (ProtocolException e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ }catch (IllegalStateException e){
|
|
|
+ LoggerUtils.debug(getClass(),"前端网页已关闭");
|
|
|
+ }finally {
|
|
|
+ inputStreamReader.close();
|
|
|
+ outputStream.close();
|
|
|
+ reader.close();
|
|
|
+ res.sseEmitter.complete();
|
|
|
+ SseMap.sseEmitterMap.remove(inputSendChat.getUserId());
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ public String getBaiduAccessToken() throws IOException {
|
|
|
+ MediaType mediaType = MediaType.parse("application/json");
|
|
|
+ RequestBody body = RequestBody.create(mediaType, "");
|
|
|
+ StringBuffer url= new StringBuffer(BAIDU_ACCESSTOKEN_URL)
|
|
|
+ .append("client_id=").append(baiduApiKey).append("&client_secret=").append(baiduSecretKey);
|
|
|
+ Request request = new Request.Builder()
|
|
|
+ .url(url.toString())
|
|
|
+ .method("POST", body)
|
|
|
+ .addHeader("Content-Type", "application/json")
|
|
|
+ .addHeader("Accept", "application/json")
|
|
|
+ .build();
|
|
|
+ Response response = HTTP_CLIENT.newCall(request).execute();
|
|
|
+ String result=response.body().string();
|
|
|
+ HashMap<String,Object> map = JSON.parseObject(result, HashMap.class);
|
|
|
+ String accessToken = map.get("access_token").toString();
|
|
|
+ LoggerUtils.debug(getClass(),"获取accessToken="+accessToken);
|
|
|
+ return accessToken;
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ private void sendJitaoBaiWen(String userId, String s) throws IOException {
|
|
|
+ SseResult res = SseMap.sseEmitterMap.get(userId);
|
|
|
+ res.sseEmitter.send(s);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ private String getRedisBaiduAccessToken() {
|
|
|
+ String redisAccessToken;
|
|
|
+ String redisTime=redisUtil.getString("baiduAccessTime");
|
|
|
+ if (redisTime !=null){
|
|
|
+ Calendar cal = Calendar.getInstance();
|
|
|
+ Long redisAccessTime=Long.valueOf(redisTime);
|
|
|
+ if (cal.getTimeInMillis()>redisAccessTime){
|
|
|
+ redisUtil.deleteString("baiduAccessToken");
|
|
|
+ redisUtil.deleteString("baiduAccessTime");
|
|
|
+ redisAccessToken=pushRedisBaiduAccessToken();
|
|
|
+ LoggerUtils.debug(getClass(),"accessToken过期,重新获取!");
|
|
|
+ }else {
|
|
|
+ redisAccessToken= redisUtil.getString("baiduAccessToken");
|
|
|
+ }
|
|
|
+ }else {
|
|
|
+ redisAccessToken=pushRedisBaiduAccessToken();
|
|
|
+ LoggerUtils.debug(getClass(),"accessToken不存在,从百度获取!");
|
|
|
+ }
|
|
|
+ return redisAccessToken;
|
|
|
+ }
|
|
|
+
|
|
|
+ private String pushRedisBaiduAccessToken() {
|
|
|
+ String baiduAccessToken;
|
|
|
+ try {
|
|
|
+ baiduAccessToken = getBaiduAccessToken();
|
|
|
+ } catch (IOException e) {
|
|
|
+ throw new BusinessException("baiduAccessToken获取失败");
|
|
|
+ }
|
|
|
+ //获取当前系统时间
|
|
|
+ Calendar cal = Calendar.getInstance();
|
|
|
+ //将时间增加三十天
|
|
|
+ cal.add(Calendar.DATE, 30);
|
|
|
+ //获取改变后的时间
|
|
|
+ Long baiduAccessTime= cal.getTimeInMillis();
|
|
|
+ redisUtil.deleteString("baiduAccessToken");
|
|
|
+ redisUtil.deleteString("baiduAccessTime");
|
|
|
+ redisUtil.setString("baiduAccessToken",baiduAccessToken);
|
|
|
+ redisUtil.setString("baiduAccessTime",baiduAccessTime.toString());
|
|
|
+ return baiduAccessToken;
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+}
|