This commit is contained in:
ageer
2024-04-01 22:21:29 +08:00
parent cead269b19
commit dea23f13ef
552 changed files with 2144 additions and 154437 deletions

View File

@@ -19,6 +19,7 @@ import java.util.List;
@RequestMapping("/mj/account")
@RequiredArgsConstructor
public class AccountController {
private final DiscordLoadBalancer loadBalancer;
@ApiOperation(value = "指定ID获取账号")

View File

@@ -0,0 +1,100 @@
package com.xmzs.midjourney.controller;
import cn.hutool.json.JSONUtil;
import com.xmzs.common.chat.constant.OpenAIConst;
import com.xmzs.common.core.domain.model.LoginUser;
import com.xmzs.common.core.exception.base.BaseException;
import com.xmzs.common.satoken.utils.LoginHelper;
import com.xmzs.midjourney.domain.InsightFace;
import com.xmzs.system.domain.bo.ChatMessageBo;
import com.xmzs.system.service.IChatMessageService;
import com.xmzs.system.service.IChatService;
import com.xmzs.system.service.ISseService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okio.Buffer;
import okio.BufferedSink;
import okio.GzipSink;
import okio.Okio;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
@Api(tags = "任务查询")
@RestController
@RequestMapping("/mj")
@RequiredArgsConstructor
@Slf4j
public class FaceController {
@Value("${chat.apiKey}")
private String apiKey;
@Value("${chat.apiHost}")
private String apiHost;
@Autowired
private IChatService chatService;
@Autowired
private ISseService sseService;
@ApiOperation(value = "换脸")
@PostMapping("/insight-face/swap")
public String insightFace(@RequestBody InsightFace insightFace) {
// 查询是否是付费用户
sseService.checkUserGrade();
// 扣除接口费用
chatService.mjTaskDeduct("换脸", OpenAIConst.MJ_COST_TYPE2);
OkHttpClient client = new OkHttpClient.Builder()
.connectTimeout(30, TimeUnit.SECONDS) // 连接超时时间
.writeTimeout(30, TimeUnit.SECONDS) // 写入超时时间
.readTimeout(30, TimeUnit.SECONDS) // 读取超时时间
.build();
// 创建一个Request对象来配置你的请求
// 创建请求体这里使用JSON作为媒体类型
String jsonStr = JSONUtil.toJsonStr(insightFace);
MediaType JSON = MediaType.get("application/json; charset=utf-8");
okhttp3.RequestBody body = okhttp3.RequestBody.create(jsonStr, JSON);
Buffer buffer = new Buffer();
GzipSink gzipSink = new GzipSink(buffer);
BufferedSink gzipBufferedSink = Okio.buffer(gzipSink);
try {
body.writeTo(gzipBufferedSink);
gzipBufferedSink.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
// 创建POST请求
Request request = new Request.Builder()
.header("mj-api-secret", apiKey)
.header("Content-Encoding", "gzip")
.url(apiHost + "mj/insight-face/swap") // 替换为你的URL
.post(body)
.build();
try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
if (response.body() != null) {
return response.body().string();
}
} catch (IOException e) {
log.error("换脸失败! {}", e.getMessage());
}
return null;
}
}

View File

@@ -1,14 +1,20 @@
package com.xmzs.midjourney.controller;
import cn.hutool.core.comparator.CompareUtil;
import cn.hutool.json.JSONUtil;
import com.xmzs.midjourney.dto.SubmitImagineDTO;
import com.xmzs.midjourney.dto.TaskConditionDTO;
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
import com.xmzs.midjourney.result.SubmitResultVO;
import com.xmzs.midjourney.service.TaskStoreService;
import com.xmzs.midjourney.support.Task;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import io.swagger.annotations.ApiParam;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
@@ -16,6 +22,7 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.IOException;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
@@ -25,15 +32,35 @@ import java.util.Objects;
@RestController
@RequestMapping("/mj/task")
@RequiredArgsConstructor
@Slf4j
public class TaskController {
private final TaskStoreService taskStoreService;
private final DiscordLoadBalancer discordLoadBalancer;
@Value("${chat.apiKey}")
private String apiKey;
@Value("${chat.apiHost}")
private String apiHost;
@ApiOperation(value = "指定ID获取任务")
@GetMapping("/{id}/fetch")
public Task fetch(@ApiParam(value = "任务ID") @PathVariable String id) {
return this.taskStoreService.get(id);
}
public String fetch(@ApiParam(value = "任务ID") @PathVariable String id) {
OkHttpClient client = new OkHttpClient();
// 创建一个Request对象来配置你的请求
Request request = new Request.Builder()
.header("mj-api-secret", apiKey) // 设置Authorization header
.url(apiHost+"mj/task/" + id + "/fetch")
.build();
try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
if (response.body() != null) {
return response.body().string();
}
} catch (IOException e) {
log.error("任务:{}查询失败:{}",id,e.getMessage());
}
return null;
}
@ApiOperation(value = "查询任务队列")
@GetMapping("/queue")

View File

@@ -0,0 +1,23 @@
package com.xmzs.midjourney.domain;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.io.Serializable;
/**
* @author WangLe
*/
@Data
@ApiModel("Discord账号")
public class InsightFace implements Serializable {
/**本人头像json*/
@ApiModelProperty("本人头像json")
private String sourceBase64;
/**明星头像json*/
@ApiModelProperty("明星头像json")
private String targetBase64;
}

View File

@@ -0,0 +1,21 @@
package com.xmzs.midjourney.dto;
import com.xmzs.midjourney.enums.TaskAction;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@ApiModel("变化任务提交参数")
public class SubmitActionDTO {
private String customId;
private String taskId;
private String state;
private String notifyHook;
}

View File

@@ -1,32 +1,152 @@
package com.xmzs.midjourney.support;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.Constants;
import com.xmzs.midjourney.ProxyProperties;
import com.xmzs.common.chat.constant.OpenAIConst;
import com.xmzs.common.core.exception.ServiceException;
import com.xmzs.system.service.IChatService;
import com.xmzs.system.service.ISseService;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.*;
import org.springframework.stereotype.Component;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.servlet.HandlerInterceptor;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.Objects;
@Component
@RequiredArgsConstructor
public class ApiAuthorizeInterceptor implements HandlerInterceptor {
private final ProxyProperties properties;
private static final Logger log = LoggerFactory.getLogger(ApiAuthorizeInterceptor.class);
private static final String API_SECRET_HEADER = "mj-api-secret";
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
if (CharSequenceUtil.isBlank(this.properties.getApiSecret())) {
return true;
}
String apiSecret = request.getHeader(Constants.API_SECRET_HEADER_NAME);
boolean authorized = CharSequenceUtil.equals(apiSecret, this.properties.getApiSecret());
if (!authorized) {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
return authorized;
}
@Value("${chat.apiKey}")
private String API_SECRET_VALUE;
@Value("${chat.apiHost}")
private String apiHost;
@Autowired
private RestTemplate restTemplate;
@Autowired
private IChatService chatService;
@Autowired
private ISseService sseService;
@Override
public boolean preHandle(@NotNull HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull Object handler) {
// 判断是否是MidJourney的请求
if (isMidJourneyRequest(request)) {
try {
// 处理请求,例如费用扣除和任务查询
processRequest(request);
// 转发请求到目标服务器
forwardRequest(request, response);
} catch (Exception e) {
// 记录错误日志,包括异常堆栈信息
log.error("转发请求时发生错误", e);
// 设置HTTP状态码和响应信息
response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value());
response.setCharacterEncoding(StandardCharsets.UTF_8.name());
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
try {
// 向客户端返回错误信息
response.getWriter().write("绘图失败!" + e.getMessage());
} catch (Exception ex) {
log.error("设置错误响应时发生错误", ex);
}
// 中断请求处理流程
return false;
}
// 中断正常的请求处理流程,因为请求已被转发
return false;
}
// 如果不是MidJourney的请求则继续正常处理
return true;
}
private boolean isMidJourneyRequest(HttpServletRequest request) {
String uri = request.getRequestURI();
return uri.startsWith("/mj") &&
!uri.matches(".*/\\d+/fetch") &&
!uri.matches("/mj/insight-face/swap") &&
!uri.matches("/mj/submit/action");
}
private void processRequest(HttpServletRequest request) {
// 处理付费用户的请求,包括费用扣除和任务查询
sseService.checkUserGrade();
String uri = request.getRequestURI();
if (uri.matches("/mj/submit/describe") || uri.matches("/mj/submit/shorten")) {
chatService.mjTaskDeduct(uri.endsWith("describe") ? "图生文" : "prompt分析", OpenAIConst.MJ_COST_TYPE2);
} else if (uri.endsWith("image-seed") || uri.endsWith("list-by-condition")) {
chatService.mjTaskDeduct(uri.endsWith("image-seed") ? "获取种子" : "任务查询", OpenAIConst.MJ_COST_TYPE3);
} else if (uri.matches("/mj/submit/.*")) {
chatService.mjTaskDeduct("文生图", OpenAIConst.MJ_COST_TYPE1);
}
}
private void forwardRequest(HttpServletRequest request, HttpServletResponse response) throws Exception {
String targetUrl = buildTargetUrl(request);
HttpEntity<String> entity = new HttpEntity<>(readRequestBody(request), copyHeaders(request));
HttpMethod method = HttpMethod.valueOf(request.getMethod());
ResponseEntity<byte[]> responseEntity = restTemplate.exchange(targetUrl, method, entity, byte[].class);
copyResponseBack(response, responseEntity);
}
private String buildTargetUrl(HttpServletRequest request) {
String uri = request.getRequestURI();
String queryString = request.getQueryString();
log.info("Forwarding URL: {}", uri);
return apiHost + uri + (queryString != null ? "?" + queryString : "");
}
private HttpHeaders copyHeaders(HttpServletRequest request) {
HttpHeaders headers = new HttpHeaders();
headers.set(API_SECRET_HEADER, API_SECRET_VALUE);
Enumeration<String> headerNames = request.getHeaderNames();
while (headerNames.hasMoreElements()) {
String headerName = headerNames.nextElement();
if (!headerName.equalsIgnoreCase(API_SECRET_HEADER) &&
!headerName.equalsIgnoreCase(HttpHeaders.CONTENT_LENGTH) &&
!headerName.equalsIgnoreCase(HttpHeaders.AUTHORIZATION)) {
headers.set(headerName, request.getHeader(headerName));
}
}
return headers;
}
private String readRequestBody(HttpServletRequest request) throws Exception {
if (request.getContentLengthLong() > 0) {
try (InputStream inputStream = request.getInputStream()) {
return StreamUtils.copyToString(inputStream, StandardCharsets.UTF_8);
}
}
return "";
}
private void copyResponseBack(HttpServletResponse response, ResponseEntity<byte[]> responseEntity) throws Exception {
HttpHeaders responseHeaders = responseEntity.getHeaders();
responseHeaders.forEach((key, values) -> {
if (!key.equalsIgnoreCase(API_SECRET_HEADER)) {
response.addHeader(key, String.join(",", values));
}
});
// 设置响应内容类型为UTF-8防止乱码
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
response.setCharacterEncoding(StandardCharsets.UTF_8.name());
HttpStatus status = HttpStatus.resolve(responseEntity.getStatusCode().value());
response.setStatus(Objects.requireNonNullElse(status, HttpStatus.INTERNAL_SERVER_ERROR).value());
if (responseEntity.getBody() != null) {
StreamUtils.copy(responseEntity.getBody(), response.getOutputStream());
}
}
}

View File

@@ -9,8 +9,6 @@ import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.ViewControllerRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {
@Resource
@@ -27,7 +25,7 @@ public class WebMvcConfig implements WebMvcConfigurer {
public void addInterceptors(InterceptorRegistry registry) {
if (CharSequenceUtil.isNotBlank(this.properties.getApiSecret())) {
registry.addInterceptor(this.apiAuthorizeInterceptor)
.addPathPatterns("/submit/**", "/task/**", "/account/**");
.addPathPatterns("/mj/submit/**", "/mj/task/**", "/mj/account/**");
}
}