package com.goafanti.common.utils; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.goafanti.baiduAI.BaiduChatErrorEnums; import com.goafanti.baiduAI.bo.*; import com.goafanti.common.error.BusinessException; import okhttp3.*; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Component; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.ProtocolException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.Calendar; import java.util.HashMap; import java.util.Map; import java.util.concurrent.TimeUnit; public class BaiduChatUtils { @Value(value = "${baidu.ApiKey}") private String baiduApiKey=null; @Value(value = "${baidu.SecretKey}") private String baiduSecretKey=null; @Autowired private RedisUtil redisUtil; /*文心一言地址*/ private static final String BAIDU_CHAT_WXYY_URL="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token="; /*Ernie-Lite地址*/ private static final String BAIDU_CHAT_ERNIE_LITE_URL="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token="; /*accessToken获取地址*/ private static final String BAIDU_ACCESSTOKEN_URL="https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&"; static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().connectTimeout(120000, TimeUnit.MILLISECONDS) .readTimeout(120000, TimeUnit.MILLISECONDS) .build(); public String getBaiduAccessToken() throws IOException { MediaType mediaType = MediaType.parse("application/json"); RequestBody body = RequestBody.create(mediaType, ""); StringBuffer url= new StringBuffer(BAIDU_ACCESSTOKEN_URL) .append("client_id=").append(baiduApiKey).append("&client_secret=").append(baiduSecretKey); Request request = new Request.Builder() .url(url.toString()) .method("POST", body) .addHeader("Content-Type", "application/json") .addHeader("Accept", "application/json") .build(); Response response = HTTP_CLIENT.newCall(request).execute(); String result=response.body().string(); HashMap map = JSON.parseObject(result, HashMap.class); String accessToken = map.get("access_token").toString(); LoggerUtils.debug(getClass(),"获取accessToken="+accessToken); return accessToken; } public String sendBaiduAI(InputSendChat in) throws IOException{ String accessToken = getRedisBaiduAccessToken(); MediaType mediaType = MediaType.parse("application/json"); RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(in)); Request request = new Request.Builder() .url(BAIDU_CHAT_WXYY_URL + accessToken) .method("POST", body) .addHeader("Content-Type", "application/json") .build(); Response response = HTTP_CLIENT.newCall(request).execute(); String result=response.body().string(); return result; } @Async public void sendBaiduAiStream(InputSendChat inputSendChat) throws IOException { try { String baidu_url=BAIDU_CHAT_WXYY_URL+getRedisBaiduAccessToken(); URL url = new URL(baidu_url); HttpURLConnection connection = (HttpURLConnection) url.openConnection(); connection.setRequestMethod("POST"); connection.setRequestProperty("Content-Type", "application/json"); connection.setDoInput(true); connection.setDoOutput(true); // 构造请求体 // String requestBody = "{\"messages\":[{\"role\":\"user\",\"content\":\"给我介绍一条从四川自驾到拉萨的路线\"}],\"stream\":true}"; String requestBody=JSON.toJSONString(inputSendChat); byte[] postData = requestBody.getBytes(StandardCharsets.UTF_8); connection.setRequestProperty("Content-Length", String.valueOf(postData.length)); connection.getOutputStream().write(postData); InputStream responseStream = connection.getInputStream(); BufferedReader reader = new BufferedReader(new InputStreamReader(responseStream,"UTF-8")); String line; SseResult res = SseMap.sseEmitterMap.get(inputSendChat.getUserId()); while ((line = reader.readLine())!= null) { // 每行数据中以 "data:" 开头的部分即为实际的响应数据 System.out.println("relut="+line+"。"); if (line.startsWith("data:")) { String data = line.substring("data:".length()).trim(); JSONObject jsonObject = JSONObject.parseObject(data); Boolean isEnd = jsonObject.getBoolean("is_end"); if(isEnd){ break; } // OutSendChatOK out =jsonObject.toJavaObject(OutSendChatOK.class); // System.out.println(data); res.sseEmitter.send(data); }else if(line.startsWith("{")) { JSONObject jsonObject = JSONObject.parseObject(line); jsonObject.getString("error_code"); res.sseEmitter.send(line); } } res.sseEmitter.complete(); reader.close(); } catch (MalformedURLException e) { e.printStackTrace(); sendJitaoBaiWen(inputSendChat.getUserId(),"data:{\"error_code\":\"2\",\"error_msg\":\"域名解析异常\"}"); } catch (ProtocolException e) { e.printStackTrace(); } SseMap.sseEmitterMap.remove(inputSendChat.getUserId()); } private OutSendChat pushResultToOutSendChat(String result) { Map resultMap=JSON.parseObject(result, Map.class); Integer errorCode= (Integer) resultMap.get("error_code"); if (errorCode!=null){ OutChatER res=new OutChatER(); res.setErrorCode(errorCode); if (errorCode.equals("336003")){ res.setErrorMsg(BaiduChatErrorEnums.BycodeGetMsg(errorCode)+resultMap.get("error_msg")); }else { res.setErrorMsg(BaiduChatErrorEnums.BycodeGetMsg(errorCode)); } return res; }else { OutSendChatOK res=new OutSendChatOK(); res=JSON.parseObject(result,OutSendChatOK.class); return res; } } private void sendJitaoBaiWen(String userId, String s) throws IOException { SseResult res = SseMap.sseEmitterMap.get(userId); res.sseEmitter.send(s); } private String getRedisBaiduAccessToken() { String redisAccessToken=null; String redisTime=redisUtil.getString("baiduAccessTime"); //没有 if (redisTime !=null){ Calendar cal = Calendar.getInstance(); Long redisAccessTime=Long.valueOf(redisTime); if (cal.getTimeInMillis()>redisAccessTime){ redisUtil.deleteString("baiduAccessToken"); redisUtil.deleteString("baiduAccessTime"); redisAccessToken=pushRedisBaiduAccessToken(); LoggerUtils.debug(getClass(),"accessToken过期,重新获取!"); }else { redisAccessToken= redisUtil.getString("baiduAccessToken"); LoggerUtils.debug(getClass(),"accessToken从redis获取成功!"); } }else { redisAccessToken=pushRedisBaiduAccessToken(); LoggerUtils.debug(getClass(),"accessToken不存在,从百度获取!"); } return redisAccessToken; } private String pushRedisBaiduAccessToken() { String baiduAccessToken; try { baiduAccessToken = getBaiduAccessToken(); } catch (IOException e) { throw new BusinessException("baiduAccessToken获取失败"); } //获取当前系统时间 Calendar cal = Calendar.getInstance(); //将时间增加三十天 cal.add(Calendar.DATE, 30); //获取改变后的时间 Long baiduAccessTime= cal.getTimeInMillis(); redisUtil.setString("baiduAccessToken",baiduAccessToken); redisUtil.setString("baiduAccessTime",baiduAccessTime.toString()); return baiduAccessToken; } }