Skip to content

Commit 95b107c

Browse files
authored
feat: Change the Metadata SDK _Context class to an external class (#1519)
* feat: Change the Metadata SDK _Context class to an external class * Add base schema class for context * Add additional context schema types * Add additional context schema types * Add create method to Context. * Fix unit test failure. * add unit tests * fix lint issue * Add Context to root __init__. * correct import path
1 parent fd55daf commit 95b107c

11 files changed

+575
-46
lines changed

google/cloud/aiplatform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
ExperimentRun = metadata.experiment_run_resource.ExperimentRun
9696
Artifact = metadata.artifact.Artifact
9797
Execution = metadata.execution.Execution
98+
Context = metadata.context.Context
9899

99100

100101
__all__ = (

google/cloud/aiplatform/metadata/context.py

Lines changed: 152 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import proto
2121

22+
from google.auth import credentials as auth_credentials
23+
2224
from google.cloud.aiplatform import base
2325
from google.cloud.aiplatform import utils
2426
from google.cloud.aiplatform.metadata import utils as metadata_utils
@@ -31,10 +33,11 @@
3133
)
3234
from google.cloud.aiplatform.metadata import artifact
3335
from google.cloud.aiplatform.metadata import execution
36+
from google.cloud.aiplatform.metadata import metadata_store
3437
from google.cloud.aiplatform.metadata import resource
3538

3639

37-
class _Context(resource._Resource):
40+
class Context(resource._Resource):
3841
"""Metadata Context resource for Vertex AI"""
3942

4043
_resource_noun = "contexts"
@@ -81,6 +84,153 @@ def get_artifacts(self) -> List[artifact.Artifact]:
8184
credentials=self.credentials,
8285
)
8386

87+
@classmethod
88+
def create(
89+
cls,
90+
schema_title: str,
91+
*,
92+
resource_id: Optional[str] = None,
93+
display_name: Optional[str] = None,
94+
schema_version: Optional[str] = None,
95+
description: Optional[str] = None,
96+
metadata: Optional[Dict] = None,
97+
metadata_store_id: Optional[str] = "default",
98+
project: Optional[str] = None,
99+
location: Optional[str] = None,
100+
credentials: Optional[auth_credentials.Credentials] = None,
101+
) -> "Context":
102+
"""Creates a new Metadata Context.
103+
104+
Args:
105+
schema_title (str):
106+
Required. schema_title identifies the schema title used by the Context.
107+
Please reference https://cloud.google.com/vertex-ai/docs/ml-metadata/system-schemas.
108+
resource_id (str):
109+
Optional. The <resource_id> portion of the Context name with
110+
the format. This is globally unique in a metadataStore:
111+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/Contexts/<resource_id>.
112+
display_name (str):
113+
Optional. The user-defined name of the Context.
114+
schema_version (str):
115+
Optional. schema_version specifies the version used by the Context.
116+
If not set, defaults to use the latest version.
117+
description (str):
118+
Optional. Describes the purpose of the Context to be created.
119+
metadata (Dict):
120+
Optional. Contains the metadata information that will be stored in the Context.
121+
metadata_store_id (str):
122+
Optional. The <metadata_store_id> portion of the resource name with
123+
the format:
124+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/Contexts/<resource_id>
125+
If not provided, the MetadataStore's ID will be set to "default".
126+
project (str):
127+
Optional. Project used to create this Context. Overrides project set in
128+
aiplatform.init.
129+
location (str):
130+
Optional. Location used to create this Context. Overrides location set in
131+
aiplatform.init.
132+
credentials (auth_credentials.Credentials):
133+
Optional. Custom credentials used to create this Context. Overrides
134+
credentials set in aiplatform.init.
135+
136+
Returns:
137+
Context: Instantiated representation of the managed Metadata Context.
138+
"""
139+
return cls._create(
140+
resource_id=resource_id,
141+
schema_title=schema_title,
142+
display_name=display_name,
143+
schema_version=schema_version,
144+
description=description,
145+
metadata=metadata,
146+
metadata_store_id=metadata_store_id,
147+
project=project,
148+
location=location,
149+
credentials=credentials,
150+
)
151+
152+
# TODO() refactor code to move _create to _Resource class.
153+
@classmethod
154+
def _create(
155+
cls,
156+
resource_id: str,
157+
schema_title: str,
158+
display_name: Optional[str] = None,
159+
schema_version: Optional[str] = None,
160+
description: Optional[str] = None,
161+
metadata: Optional[Dict] = None,
162+
metadata_store_id: Optional[str] = "default",
163+
project: Optional[str] = None,
164+
location: Optional[str] = None,
165+
credentials: Optional[auth_credentials.Credentials] = None,
166+
) -> "Context":
167+
"""Creates a new Metadata resource.
168+
169+
Args:
170+
resource_id (str):
171+
Required. The <resource_id> portion of the resource name with
172+
the format:
173+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
174+
schema_title (str):
175+
Required. schema_title identifies the schema title used by the resource.
176+
display_name (str):
177+
Optional. The user-defined name of the resource.
178+
schema_version (str):
179+
Optional. schema_version specifies the version used by the resource.
180+
If not set, defaults to use the latest version.
181+
description (str):
182+
Optional. Describes the purpose of the resource to be created.
183+
metadata (Dict):
184+
Optional. Contains the metadata information that will be stored in the resource.
185+
metadata_store_id (str):
186+
The <metadata_store_id> portion of the resource name with
187+
the format:
188+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
189+
If not provided, the MetadataStore's ID will be set to "default".
190+
project (str):
191+
Project used to create this resource. Overrides project set in
192+
aiplatform.init.
193+
location (str):
194+
Location used to create this resource. Overrides location set in
195+
aiplatform.init.
196+
credentials (auth_credentials.Credentials):
197+
Custom credentials used to create this resource. Overrides
198+
credentials set in aiplatform.init.
199+
200+
Returns:
201+
resource (_Resource):
202+
Instantiated representation of the managed Metadata resource.
203+
204+
"""
205+
api_client = cls._instantiate_client(location=location, credentials=credentials)
206+
207+
parent = utils.full_resource_name(
208+
resource_name=metadata_store_id,
209+
resource_noun=metadata_store._MetadataStore._resource_noun,
210+
parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
211+
format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
212+
project=project,
213+
location=location,
214+
)
215+
216+
resource = cls._create_resource(
217+
client=api_client,
218+
parent=parent,
219+
resource_id=resource_id,
220+
schema_title=schema_title,
221+
display_name=display_name,
222+
schema_version=schema_version,
223+
description=description,
224+
metadata=metadata,
225+
)
226+
227+
self = cls._empty_constructor(
228+
project=project, location=location, credentials=credentials
229+
)
230+
self._gca_resource = resource
231+
232+
return self
233+
84234
@classmethod
85235
def _create_resource(
86236
cls,
@@ -147,7 +297,7 @@ def _list_resources(
147297
)
148298
return client.list_contexts(request=list_request)
149299

150-
def add_context_children(self, contexts: List["_Context"]):
300+
def add_context_children(self, contexts: List["Context"]):
151301
"""Adds the provided contexts as children of this context.
152302
153303
Args:

google/cloud/aiplatform/metadata/experiment_resources.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,13 @@ def __init__(
119119
)
120120

121121
with _SetLoggerLevel(resource):
122-
experiment_context = context._Context(**metadata_args)
122+
experiment_context = context.Context(**metadata_args)
123123
self._validate_experiment_context(experiment_context)
124124

125125
self._metadata_context = experiment_context
126126

127127
@staticmethod
128-
def _validate_experiment_context(experiment_context: context._Context):
128+
def _validate_experiment_context(experiment_context: context.Context):
129129
"""Validates this context is an experiment context.
130130
131131
Args:
@@ -146,7 +146,7 @@ def _validate_experiment_context(experiment_context: context._Context):
146146
)
147147

148148
@staticmethod
149-
def _is_tensorboard_experiment(context: context._Context) -> bool:
149+
def _is_tensorboard_experiment(context: context.Context) -> bool:
150150
"""Returns True if Experiment is a Tensorboard Experiment created by CustomJob."""
151151
return constants.TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD in context.metadata
152152

@@ -192,7 +192,7 @@ def create(
192192
)
193193

194194
with _SetLoggerLevel(resource):
195-
experiment_context = context._Context._create(
195+
experiment_context = context.Context._create(
196196
resource_id=experiment_name,
197197
display_name=experiment_name,
198198
description=description,
@@ -248,7 +248,7 @@ def get_or_create(
248248
)
249249

250250
with _SetLoggerLevel(resource):
251-
experiment_context = context._Context.get_or_create(
251+
experiment_context = context.Context.get_or_create(
252252
resource_id=experiment_name,
253253
display_name=experiment_name,
254254
description=description,
@@ -303,7 +303,7 @@ def list(
303303
)
304304

305305
with _SetLoggerLevel(resource):
306-
experiment_contexts = context._Context.list(
306+
experiment_contexts = context.Context.list(
307307
filter=filter_str,
308308
project=project,
309309
location=location,
@@ -341,7 +341,7 @@ def delete(self, *, delete_backing_tensorboard_runs: bool = False):
341341
runs under this experiment that we used to store time series metrics.
342342
"""
343343

344-
experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context._Context][
344+
experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context.Context][
345345
constants.SYSTEM_EXPERIMENT_RUN
346346
].list(experiment=self)
347347
for experiment_run in experiment_runs:
@@ -380,11 +380,11 @@ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821
380380

381381
filter_str = metadata_utils._make_filter_string(
382382
schema_title=sorted(
383-
list(_SUPPORTED_LOGGABLE_RESOURCES[context._Context].keys())
383+
list(_SUPPORTED_LOGGABLE_RESOURCES[context.Context].keys())
384384
),
385385
parent_contexts=[self._metadata_context.resource_name],
386386
)
387-
contexts = context._Context.list(filter_str, **service_request_args)
387+
contexts = context.Context.list(filter_str, **service_request_args)
388388

389389
filter_str = metadata_utils._make_filter_string(
390390
schema_title=list(
@@ -398,7 +398,7 @@ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821
398398
rows = []
399399
for metadata_context in contexts:
400400
row_dict = (
401-
_SUPPORTED_LOGGABLE_RESOURCES[context._Context][
401+
_SUPPORTED_LOGGABLE_RESOURCES[context.Context][
402402
metadata_context.schema_title
403403
]
404404
._query_experiment_row(metadata_context)
@@ -568,7 +568,7 @@ class _VertexResourceWithMetadata(NamedTuple):
568568
"""Represents a resource coupled with it's metadata representation"""
569569

570570
resource: base.VertexAiResourceNoun
571-
metadata: Union[artifact.Artifact, execution.Execution, context._Context]
571+
metadata: Union[artifact.Artifact, execution.Execution, context.Context]
572572

573573

574574
class _ExperimentLoggableSchema(NamedTuple):
@@ -581,7 +581,7 @@ class _ExperimentLoggableSchema(NamedTuple):
581581
"""
582582

583583
title: str
584-
type: Union[Type[context._Context], Type[execution.Execution]] = context._Context
584+
type: Union[Type[context.Context], Type[execution.Execution]] = context.Context
585585

586586

587587
class _ExperimentLoggable(abc.ABC):
@@ -618,7 +618,7 @@ class PipelineJob(..., experiment_loggable_schemas=
618618
_SUPPORTED_LOGGABLE_RESOURCES[schema.type][schema.title] = cls
619619

620620
@abc.abstractmethod
621-
def _get_context(self) -> context._Context:
621+
def _get_context(self) -> context.Context:
622622
"""Should return the metadata context that represents this resource.
623623
624624
The subclass should enforce this context exists.
@@ -631,7 +631,7 @@ def _get_context(self) -> context._Context:
631631
@classmethod
632632
@abc.abstractmethod
633633
def _query_experiment_row(
634-
cls, node: Union[context._Context, execution.Execution]
634+
cls, node: Union[context.Context, execution.Execution]
635635
) -> _ExperimentRow:
636636
"""Should return parameters and metrics for this resource as a run row.
637637
@@ -716,6 +716,6 @@ def _associate_to_experiment(self, experiment: Union[str, Experiment]):
716716
# Context -> 'system.ExperimentRun' -> aiplatform.ExperimentRun
717717
# Execution -> 'system.Run' -> aiplatform.ExperimentRun
718718
_SUPPORTED_LOGGABLE_RESOURCES: Dict[
719-
Union[Type[context._Context], Type[execution.Execution]],
719+
Union[Type[context.Context], Type[execution.Execution]],
720720
Dict[str, _ExperimentLoggable],
721-
] = {execution.Execution: dict(), context._Context: dict()}
721+
] = {execution.Execution: dict(), context.Context: dict()}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy