Skip to content

Commit

Permalink
add AWS bedlock components
Browse files Browse the repository at this point in the history
  • Loading branch information
fumito-ito committed Mar 23, 2024
1 parent 9c0c1e6 commit 1a4f4e1
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 0 deletions.
36 changes: 36 additions & 0 deletions Package.resolved
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
{
"pins" : [
{
"identity" : "aws-crt-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/awslabs/aws-crt-swift",
"state" : {
"revision" : "0d0a0cf2e2cb780ceeceac190b4ede94f4f96902",
"version" : "0.26.0"
}
},
{
"identity" : "aws-sdk-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/awslabs/aws-sdk-swift",
"state" : {
"revision" : "7b0d55676eb896e472662922e80ebe30ae3ca8ed",
"version" : "0.39.0"
}
},
{
"identity" : "collectionconcurrencykit",
"kind" : "remoteSourceControl",
Expand All @@ -18,6 +36,15 @@
"version" : "1.8.1"
}
},
{
"identity" : "smithy-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/smithy-lang/smithy-swift",
"state" : {
"revision" : "0267f0c649558cbc1323bbc5a0c8b2403239655b",
"version" : "0.44.0"
}
},
{
"identity" : "sourcekitten",
"kind" : "remoteSourceControl",
Expand All @@ -36,6 +63,15 @@
"version" : "1.2.3"
}
},
{
"identity" : "swift-log",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-log.git",
"state" : {
"revision" : "e97a6fcb1ab07462881ac165fdbb37f067e205d5",
"version" : "1.5.4"
}
},
{
"identity" : "swift-syntax",
"kind" : "remoteSourceControl",
Expand Down
7 changes: 7 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ let package = Package(
],
dependencies: [
.package(url: "https://github.com/realm/SwiftLint", .upToNextMajor(from: "0.54.0")),
.package(url: "https://github.com/awslabs/aws-sdk-swift", from: "0.32.0"),
],
targets: [
// Targets are the basic building blocks of a package, defining a module or a test suite.
Expand All @@ -28,5 +29,11 @@ let package = Package(
.testTarget(
name: "AnthropicSwiftSDKTests",
dependencies: ["AnthropicSwiftSDK"]),
.target(
name: "AnthropicSwiftSDK-Bedrock",
dependencies: [
"AnthropicSwiftSDK",
.product(name: "AWSBedrockRuntime", package: "aws-sdk-swift")
])
]
)
18 changes: 18 additions & 0 deletions Sources/AnthropicSwiftSDK-Bedrock/BedrockAnthropicClient.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//
// BedrockAnthropicClient.swift
//
//
// Created by 伊藤史 on 2024/03/22.
//

import Foundation
import AWSBedrockRuntime
import AnthropicSwiftSDK

public final class BedrockAnthropicClient {
public let messages: AnthropicSwiftSDK_Bedrock.Messages

init(client: BedrockRuntimeClient, model: Model) {
self.messages = .init(client: client, model: model)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//
// AnthropicSwiftSDK_Model+Extension.swift
//
//
// Created by 伊藤史 on 2024/03/23.
//

import Foundation
import AnthropicSwiftSDK

extension AnthropicSwiftSDK.Model {
/// Model name for Amazon Bedrock
///
/// for more detail, see https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var bedrockModelName: String? {
switch self {
case .claude_3_Opus:
return nil
case .claude_3_Sonnet:
return "anthropic.claude-3-sonnet-20240229-v1:0"
case .claude_3_Haiku:
return "anthropic.claude-3-haiku-20240307-v1:0"
case let .custom(modelName):
return modelName
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//
// BedrockClient+Extension.swift
//
//
// Created by 伊藤史 on 2024/03/22.
//

import Foundation
import AWSBedrockRuntime
import AnthropicSwiftSDK

public extension BedrockRuntimeClient {
static func useAnthropic(_ client: BedrockRuntimeClient, model: Model) -> BedrockAnthropicClient {
BedrockAnthropicClient(client: client, model: model)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//
// InvokeModelInput+Extension.swift
//
//
// Created by 伊藤史 on 2024/03/23.
//

import Foundation
import AnthropicSwiftSDK
import AWSBedrockRuntime

extension InvokeModelInput {
init(accept: String, request: MessagesRequest, contentType: String) throws {
let data = try anthropicJSONEncoder.encode(request)

self.init(
accept: accept,
body: data,
contentType: contentType,
modelId: request.model.bedrockModelName
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//
// InvokeModelWithResponseStreamInput+Extension.swift
//
//
// Created by 伊藤史 on 2024/03/23.
//

import Foundation
import AnthropicSwiftSDK
import AWSBedrockRuntime

extension InvokeModelWithResponseStreamInput {
init(accept: String, request: MessagesRequest, contentType: String) throws {
let data = try anthropicJSONEncoder.encode(request)

self.init(
accept: accept,
body: data,
contentType: contentType,
modelId: request.model.bedrockModelName
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//
// MessagesResponse+Extension.swift
//
//
// Created by 伊藤史 on 2024/03/23.
//

import Foundation
import AnthropicSwiftSDK
import AWSBedrockRuntime

extension MessagesResponse {
init (from invokeModelOutput: InvokeModelOutput) throws {
guard let data = invokeModelOutput.body else {
fatalError()
}

self = try anthropicJSONDecoder.decode(MessagesResponse.self, from: data)
}
}
116 changes: 116 additions & 0 deletions Sources/AnthropicSwiftSDK-Bedrock/Messages.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//
// Messages.swift
//
//
// Created by 伊藤史 on 2024/03/22.
//

import Foundation
import AnthropicSwiftSDK
import AWSBedrockRuntime

public struct Messages {
private let acceptContentType = "application/json"
private let requestContentType = "application/json"
private let client: BedrockRuntimeClient

public let model: Model

init(client: BedrockRuntimeClient, model: Model) {
self.client = client
self.model = model
}

public func createMessage(
_ messages: [Message],
model: Model = .claude_3_Haiku,
system: String? = nil,
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
temperature: Double? = nil,
topP: Int? = nil,
topK: Int? = nil
) async throws -> MessagesResponse {
// In the inference call, fill the body field with a JSON object that conforms the type call you want to make [Anthropic Claude Messages API](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html).
let requestBody = MessagesRequest(
model: model,
messages: messages,
system: system,
maxTokens: maxTokens,
metaData: metaData,
stopSequences: stopSequence,
stream: false,
temperature: temperature,
topP: topP,
topK: topK
)

let input = try InvokeModelInput(
accept: acceptContentType,
request: requestBody,
contentType: requestContentType
)

let response = try await client.invokeModel(input: input)

return try MessagesResponse(from: response)
}

public func streamMessage(
_ messages: [Message],
model: Model = .claude_3_Haiku,
system: String? = nil,
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
temperature: Double? = nil,
topP: Int? = nil,
topK: Int? = nil
) async throws -> AsyncThrowingStream<StreamingResponse, Error> {
// In the inference call, fill the body field with a JSON object that conforms the type call you want to make [Anthropic Claude Messages API](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html).
let requestBody = MessagesRequest(
model: model,
messages: messages,
system: system,
maxTokens: maxTokens,
metaData: metaData,
stopSequences: stopSequence,
stream: true,
temperature: temperature,
topP: topP,
topK: topK
)

let data = try anthropicJSONEncoder.encode(requestBody)

let input = try InvokeModelWithResponseStreamInput(
accept: acceptContentType,
request: requestBody,
contentType: requestContentType
)

let response = try await client.invokeModelWithResponseStream(input: input)

guard let responseStream = response.body else {
fatalError()
}

return try await AnthropicStreamingParser.parse(stream: responseStream.map { try $0.toString() })
}
}

extension BedrockRuntimeClientTypes.ResponseStream {
func toString() throws -> String {
guard case let .chunk(payload) = self else {
fatalError()
}

guard let data = payload.bytes,
let line = String(data: data, encoding: .utf8) else {
fatalError()
}

return line
}
}

0 comments on commit 1a4f4e1

Please sign in to comment.