-
Notifications
You must be signed in to change notification settings - Fork 32
/
gapic.async_client.generate_content.py
48 lines (39 loc) · 1.56 KB
/
gapic.async_client.generate_content.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import asyncio
import os
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1 import (
Content,
GenerateContentRequest,
GenerationConfig,
Part,
PredictionServiceAsyncClient,
)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
from openinference.instrumentation.vertexai import VertexAIInstrumentor
# Endpoint handed to the OTLP (gRPC) span exporter below.
endpoint = "http://127.0.0.1:4317"

# A single tracer provider with two span processors attached: one exporting
# over OTLP to `endpoint`, one exporting to the console.
tracer_provider = trace_sdk.TracerProvider()
tracer_provider.add_span_processor(
    SimpleSpanProcessor(OTLPSpanExporter(endpoint=endpoint))
)
tracer_provider.add_span_processor(
    SimpleSpanProcessor(ConsoleSpanExporter())
)

# Instrument the Vertex AI SDK with the provider configured above.
VertexAIInstrumentor().instrument(tracer_provider=tracer_provider)
# Region embedded in the model resource path built below.
location = "us-central1"
# The project ID comes from the environment; indexing os.environ directly
# raises KeyError when CLOUD_ML_PROJECT_ID is not set.
project = os.environ["CLOUD_ML_PROJECT_ID"]
model = "gemini-1.5-flash"

# The request is built once at import time: a single user turn carrying one
# text part, with generation capped at 20 output tokens.
request = GenerateContentRequest(
    {
        "contents": [
            Content(
                {
                    "role": "user",
                    "parts": [Part({"text": "Write a haiku."})],
                }
            )
        ],
        "model": f"projects/{project}/locations/{location}/publishers/google/models/{model}",
        "generation_config": GenerationConfig({"max_output_tokens": 20}),
    }
)
async def main() -> None:
    """Send the module-level ``request`` to the model and print the response.

    An async prediction client is obtained from the Vertex AI SDK's global
    configuration (``initializer.global_config.create_client``) with the
    module-level ``location`` as the location override. The awaited result
    of ``generate_content`` is printed as-is.
    """
    make_client = initializer.global_config.create_client
    prediction_client: PredictionServiceAsyncClient = make_client(
        client_class=PredictionServiceAsyncClient,
        location_override=location,
        prediction_client=True,
    )
    print(await prediction_client.generate_content(request))
if __name__ == "__main__":
    # Only drive the coroutine when executed as a script, not on import.
    asyncio.run(main())