108 lines
4.0 KiB
Python
108 lines
4.0 KiB
Python
# Copyright (c) 2023-present Plane Software, Inc. and contributors
|
|
# SPDX-License-Identifier: AGPL-3.0-only
|
|
# See the LICENSE file for details.
|
|
|
|
from openai import AuthenticationError, OpenAI, RateLimitError
|
|
|
|
from rest_framework import status
|
|
from rest_framework.response import Response
|
|
|
|
from plane.app.permissions import ROLE, allow_permission
|
|
from plane.app.serializers import WorkspaceAISettingsSerializer
|
|
from plane.db.models import Workspace, WorkspaceAICredential, WorkspaceAISettings
|
|
from plane.license.utils.encryption import decrypt_data
|
|
from plane.utils.exception_logger import log_exception
|
|
|
|
from .base import BaseAPIView
|
|
|
|
|
|
class WorkspaceAISettingsEndpoint(BaseAPIView):
    """Read and partially update a workspace's AI settings (workspace admins only)."""

    def get_settings(self, slug):
        """Return ``(workspace, ai_settings)`` for the workspace identified by ``slug``.

        The ``WorkspaceAISettings`` row is created on first access through
        ``get_or_create``. ``Workspace.objects.get`` raises when no workspace
        has this slug; how that reaches the client depends on ``BaseAPIView``'s
        exception handling.
        """
        workspace = Workspace.objects.get(slug=slug)
        # get_or_create returns (instance, created); only the instance is needed.
        settings_row = WorkspaceAISettings.objects.get_or_create(workspace=workspace)[0]
        return workspace, settings_row

    @allow_permission(allowed_roles=[ROLE.ADMIN], level="WORKSPACE")
    def get(self, request, slug):
        """Respond 200 with the serialized AI settings of the workspace."""
        workspace, settings_row = self.get_settings(slug)
        payload = WorkspaceAISettingsSerializer(
            settings_row, context={"workspace": workspace}
        ).data
        return Response(payload, status=status.HTTP_200_OK)

    @allow_permission(allowed_roles=[ROLE.ADMIN], level="WORKSPACE")
    def patch(self, request, slug):
        """Apply a partial update to the workspace's AI settings.

        Responds 200 with the serialized, saved settings when ``request.data``
        validates, otherwise 400 with the serializer's validation errors.
        """
        workspace, settings_row = self.get_settings(slug)
        serializer = WorkspaceAISettingsSerializer(
            settings_row,
            data=request.data,
            partial=True,
            context={"workspace": workspace},
        )

        # Guard clause: reject invalid payloads before touching the database.
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        serializer.save()
        return Response(serializer.data, status=status.HTTP_200_OK)
|
|
|
|
|
|
class WorkspaceAISettingsTestConnectionEndpoint(BaseAPIView):
    """Verify that a workspace's stored AI credential can reach OpenAI."""

    # Upper bound, in seconds, on each OpenAI HTTP attempt. This check runs
    # inside a synchronous request handler, so relying on the SDK's default
    # (long) timeout could tie up a worker far longer than a connection test
    # warrants.
    OPENAI_TIMEOUT_SECONDS = 15.0
    # Keep automatic SDK retries small so the worst-case request time stays
    # bounded; the user can simply re-run the test.
    OPENAI_MAX_RETRIES = 1

    @staticmethod
    def _classify_openai_error(exc):
        """Map an exception raised during the OpenAI check to ``(error_code, http_status)``."""
        # isinstance (rather than comparing class-name strings) also matches
        # subclasses and breaks loudly if the SDK renames these classes.
        if isinstance(exc, AuthenticationError):
            return "invalid_api_key", status.HTTP_400_BAD_REQUEST
        if isinstance(exc, RateLimitError):
            return "rate_limited", status.HTTP_429_TOO_MANY_REQUESTS
        return "openai_connection_failed", status.HTTP_400_BAD_REQUEST

    @allow_permission(allowed_roles=[ROLE.ADMIN], level="WORKSPACE")
    def post(self, request, slug):
        """Check the workspace's active API key by retrieving its structuring model.

        Looks up the active ``WorkspaceAICredential`` for the workspace's
        configured provider, decrypts its key and calls
        ``client.models.retrieve(ai_settings.structuring_model)``.

        Responses:
            200 ``{"ok": True, "provider": ..., "model": ...}`` on success.
            400 ``code="missing_api_key"`` when there is no active credential/key.
            400 ``code="invalid_encrypted_key"`` when decryption yields nothing.
            400 ``code="invalid_api_key"`` when OpenAI rejects the key.
            429 ``code="rate_limited"`` when OpenAI rate-limits the request.
            400 ``code="openai_connection_failed"`` for any other failure.
        """
        workspace = Workspace.objects.get(slug=slug)
        ai_settings, _ = WorkspaceAISettings.objects.get_or_create(workspace=workspace)
        credential = WorkspaceAICredential.objects.filter(
            workspace=workspace,
            provider=ai_settings.provider,
            is_active=True,
        ).first()

        if not credential or not credential.encrypted_api_key:
            return Response(
                {
                    "ok": False,
                    "code": "missing_api_key",
                    "error": "OpenAI API key is not configured for this workspace.",
                },
                status=status.HTTP_400_BAD_REQUEST,
            )

        # An empty result from decrypt_data is treated as an undecryptable key.
        api_key = decrypt_data(credential.encrypted_api_key)
        if not api_key:
            return Response(
                {
                    "ok": False,
                    "code": "invalid_encrypted_key",
                    "error": "OpenAI API key could not be decrypted.",
                },
                status=status.HTTP_400_BAD_REQUEST,
            )

        # NOTE(review): the check always uses the OpenAI client, even though the
        # credential is looked up by ai_settings.provider — confirm that OpenAI
        # is the only supported provider.
        try:
            client = OpenAI(
                api_key=api_key,
                timeout=self.OPENAI_TIMEOUT_SECONDS,
                max_retries=self.OPENAI_MAX_RETRIES,
            )
            client.models.retrieve(ai_settings.structuring_model)
        except Exception as exc:
            # Deliberately broad: any failure is reported to the caller as a
            # failed check instead of surfacing as a server error.
            log_exception(exc)
            error_code, status_code = self._classify_openai_error(exc)
            return Response(
                {
                    "ok": False,
                    "code": error_code,
                    "error": "OpenAI connection check failed.",
                },
                status=status_code,
            )

        return Response(
            {
                "ok": True,
                "provider": ai_settings.provider,
                "model": ai_settings.structuring_model,
            },
            status=status.HTTP_200_OK,
        )
|