@@ -148,6 +148,40 @@ def search_nearest_entities_mock():
148
148
yield search_nearest_entities_mock
149
149
150
150
151
+ @pytest .fixture
152
+ def transport_mock ():
153
+ with mock .patch (
154
+ "google.cloud.aiplatform_v1.services.feature_online_store_service.transports.grpc.FeatureOnlineStoreServiceGrpcTransport"
155
+ ) as transport :
156
+ transport .return_value = mock .MagicMock (autospec = True )
157
+ yield transport
158
+
159
+
160
+ @pytest .fixture
161
+ def grpc_insecure_channel_mock ():
162
+ import grpc
163
+
164
+ with mock .patch .object (grpc , "insecure_channel" , autospec = True ) as channel :
165
+ channel .return_value = mock .MagicMock (autospec = True )
166
+ yield channel
167
+
168
+
169
+ @pytest .fixture
170
+ def client_mock ():
171
+ with mock .patch (
172
+ "google.cloud.aiplatform_v1.services.feature_online_store_service.FeatureOnlineStoreServiceClient"
173
+ ) as client_mock :
174
+ yield client_mock
175
+
176
+
177
+ @pytest .fixture
178
+ def utils_client_with_override_mock ():
179
+ with mock .patch (
180
+ "google.cloud.aiplatform.utils.FeatureOnlineStoreClientWithOverride"
181
+ ) as client_mock :
182
+ yield client_mock
183
+
184
+
151
185
def fv_eq (
152
186
fv_to_check : FeatureView ,
153
187
name : str ,
@@ -428,6 +462,308 @@ def test_fetch_feature_values_optimized_no_endpoint(
428
462
FeatureView (_TEST_OPTIMIZED_FV2_PATH ).read (key = ["key1" ]).to_dict ()
429
463
430
464
465
+ def test_ffv_optimized_psc_with_no_connection_options_raises_error (
466
+ get_psc_optimized_fos_mock ,
467
+ get_optimized_fv_mock ,
468
+ ):
469
+ with pytest .raises (ValueError ) as excinfo :
470
+ FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (key = ["key1" ])
471
+
472
+ assert str (excinfo .value ) == (
473
+ "Use `connection_options` to specify an IP address. Required for optimized online store with private service connect."
474
+ )
475
+
476
+
477
+ def test_ffv_optimized_psc_with_no_connection_transport_raises_error (
478
+ get_psc_optimized_fos_mock ,
479
+ get_optimized_fv_mock ,
480
+ ):
481
+ with pytest .raises (ValueError ) as excinfo :
482
+ FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (
483
+ key = ["key1" ],
484
+ connection_options = fs_utils .ConnectionOptions (
485
+ host = "1.2.3.4" , transport = None
486
+ ),
487
+ )
488
+
489
+ assert str (excinfo .value ) == (
490
+ "Unsupported connection transport type, got transport: None"
491
+ )
492
+
493
+
494
+ def test_ffv_optimized_psc_with_bad_connection_transport_raises_error (
495
+ get_psc_optimized_fos_mock ,
496
+ get_optimized_fv_mock ,
497
+ ):
498
+ with pytest .raises (ValueError ) as excinfo :
499
+ FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (
500
+ key = ["key1" ],
501
+ connection_options = fs_utils .ConnectionOptions (
502
+ host = "1.2.3.4" , transport = "hi"
503
+ ),
504
+ )
505
+
506
+ assert str (excinfo .value ) == (
507
+ "Unsupported connection transport type, got transport: hi"
508
+ )
509
+
510
+
511
+ @pytest .mark .parametrize ("output_type" , ["dict" , "proto" ])
512
+ def test_ffv_optimized_psc (
513
+ get_psc_optimized_fos_mock ,
514
+ get_optimized_fv_mock ,
515
+ transport_mock ,
516
+ grpc_insecure_channel_mock ,
517
+ fetch_feature_values_mock ,
518
+ output_type ,
519
+ ):
520
+ rsp = FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (
521
+ key = ["key1" ],
522
+ connection_options = fs_utils .ConnectionOptions (
523
+ host = "1.2.3.4" ,
524
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
525
+ ),
526
+ )
527
+
528
+ # Ensure that we create and use insecure channel to the target.
529
+ grpc_insecure_channel_mock .assert_called_once_with ("1.2.3.4:10002" )
530
+ transport_grpc_channel = transport_mock .call_args .kwargs ["channel" ]
531
+ assert transport_grpc_channel == grpc_insecure_channel_mock .return_value
532
+
533
+ if output_type == "dict" :
534
+ assert rsp .to_dict () == {
535
+ "features" : [{"name" : "key1" , "value" : {"string_value" : "value1" }}]
536
+ }
537
+ elif output_type == "proto" :
538
+ assert rsp .to_proto () == _TEST_FV_FETCH1
539
+
540
+
541
+ def test_same_connection_options_are_equal ():
542
+ opt1 = fs_utils .ConnectionOptions (
543
+ host = "1.1.1.1" ,
544
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
545
+ )
546
+ opt2 = fs_utils .ConnectionOptions (
547
+ host = "1.1.1.1" ,
548
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
549
+ )
550
+ assert opt1 == opt2
551
+
552
+
553
+ def test_different_host_in_connection_options_are_not_equal ():
554
+ opt1 = fs_utils .ConnectionOptions (
555
+ host = "1.1.1.2" ,
556
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
557
+ )
558
+ opt2 = fs_utils .ConnectionOptions (
559
+ host = "1.1.1.1" ,
560
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
561
+ )
562
+
563
+ assert opt1 != opt2
564
+
565
+
566
+ def test_bad_transport_in_compared_connection_options_raises_error ():
567
+ opt1 = fs_utils .ConnectionOptions (
568
+ host = "1.1.1.1" ,
569
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
570
+ )
571
+ opt2 = fs_utils .ConnectionOptions (
572
+ host = "1.1.1.1" ,
573
+ transport = None ,
574
+ )
575
+
576
+ with pytest .raises (ValueError ) as excinfo :
577
+ assert opt1 != opt2
578
+
579
+ assert str (excinfo .value ) == (
580
+ "Transport 'ConnectionOptions.InsecureGrpcChannel()' cannot be compared to transport 'None'."
581
+ )
582
+
583
+
584
+ def test_bad_transport_in_connection_options_raises_error ():
585
+ opt1 = fs_utils .ConnectionOptions (
586
+ host = "1.1.1.1" ,
587
+ transport = None ,
588
+ )
589
+ opt2 = fs_utils .ConnectionOptions (
590
+ host = "1.1.1.1" ,
591
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
592
+ )
593
+
594
+ with pytest .raises (ValueError ) as excinfo :
595
+ assert opt1 != opt2
596
+
597
+ assert str (excinfo .value ) == ("Unsupported transport supplied: None" )
598
+
599
+
600
+ def test_same_connection_options_have_same_hash ():
601
+ opt1 = fs_utils .ConnectionOptions (
602
+ host = "1.1.1.1" ,
603
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
604
+ )
605
+ opt2 = fs_utils .ConnectionOptions (
606
+ host = "1.1.1.1" ,
607
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
608
+ )
609
+
610
+ d = {}
611
+ d [opt1 ] = "hi"
612
+ assert d [opt2 ] == "hi"
613
+
614
+
615
+ @pytest .mark .parametrize (
616
+ "hosts" ,
617
+ [
618
+ ("1.1.1.1" , "1.1.1.2" ),
619
+ ("1.1.1.2" , "1.1.1.1" ),
620
+ ("10.0.0.1" , "9.9.9.9" ),
621
+ ],
622
+ )
623
+ def test_different_host_in_connection_options_have_different_hash (hosts ):
624
+ opt1 = fs_utils .ConnectionOptions (
625
+ host = hosts [0 ],
626
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
627
+ )
628
+ opt2 = fs_utils .ConnectionOptions (
629
+ host = hosts [1 ],
630
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
631
+ )
632
+
633
+ d = {}
634
+ d [opt1 ] = "hi"
635
+ assert opt2 not in d
636
+
637
+
638
+ @pytest .mark .parametrize (
639
+ "transports" ,
640
+ [
641
+ (fs_utils .ConnectionOptions .InsecureGrpcChannel (), None ),
642
+ (None , fs_utils .ConnectionOptions .InsecureGrpcChannel ()),
643
+ (None , "hi" ),
644
+ ("hi" , None ),
645
+ ],
646
+ )
647
+ def test_bad_transport_in_connection_options_have_different_hash (transports ):
648
+ opt1 = fs_utils .ConnectionOptions (
649
+ host = "1.1.1.1" ,
650
+ transport = transports [0 ],
651
+ )
652
+ opt2 = fs_utils .ConnectionOptions (
653
+ host = "1.1.1.1" ,
654
+ transport = transports [1 ],
655
+ )
656
+
657
+ d = {}
658
+ d [opt1 ] = "hi"
659
+ assert opt2 not in d
660
+
661
+
662
+ def test_diff_host_and_bad_transport_in_connection_options_have_different_hash ():
663
+ opt1 = fs_utils .ConnectionOptions (
664
+ host = "1.1.1.1" ,
665
+ transport = None ,
666
+ )
667
+ opt2 = fs_utils .ConnectionOptions (
668
+ host = "9.9.9.9" ,
669
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
670
+ )
671
+
672
+ d = {}
673
+ d [opt1 ] = "hi"
674
+ assert opt2 not in d
675
+
676
+
677
+ def test_ffv_optimized_psc_reuse_client_for_same_connection_options_in_same_ffv (
678
+ get_psc_optimized_fos_mock ,
679
+ get_optimized_fv_mock ,
680
+ client_mock ,
681
+ transport_mock ,
682
+ grpc_insecure_channel_mock ,
683
+ fetch_feature_values_mock ,
684
+ ):
685
+ fv = FeatureView (_TEST_OPTIMIZED_FV1_PATH )
686
+ fv .read (
687
+ key = ["key1" ],
688
+ connection_options = fs_utils .ConnectionOptions (
689
+ host = "1.1.1.1" ,
690
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
691
+ ),
692
+ )
693
+ fv .read (
694
+ key = ["key2" ],
695
+ connection_options = fs_utils .ConnectionOptions (
696
+ host = "1.1.1.1" ,
697
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
698
+ ),
699
+ )
700
+
701
+ # Insecure channel and transport creation should only be done once.
702
+ assert grpc_insecure_channel_mock .call_args_list == [mock .call ("1.1.1.1:10002" )]
703
+ assert transport_mock .call_args_list == [
704
+ mock .call (channel = grpc_insecure_channel_mock .return_value ),
705
+ ]
706
+
707
+
708
+ def test_ffv_optimized_psc_different_client_for_different_connection_options (
709
+ get_psc_optimized_fos_mock ,
710
+ get_optimized_fv_mock ,
711
+ client_mock ,
712
+ transport_mock ,
713
+ grpc_insecure_channel_mock ,
714
+ fetch_feature_values_mock ,
715
+ ):
716
+ # Return two different grpc channels each time insecure channel is called.
717
+ import grpc
718
+
719
+ grpc_chan1 = mock .MagicMock (spec = grpc .Channel )
720
+ grpc_chan2 = mock .MagicMock (spec = grpc .Channel )
721
+ grpc_insecure_channel_mock .side_effect = [grpc_chan1 , grpc_chan2 ]
722
+
723
+ fv = FeatureView (_TEST_OPTIMIZED_FV1_PATH )
724
+ fv .read (
725
+ key = ["key1" ],
726
+ connection_options = fs_utils .ConnectionOptions (
727
+ host = "1.1.1.1" ,
728
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
729
+ ),
730
+ )
731
+ fv .read (
732
+ key = ["key2" ],
733
+ connection_options = fs_utils .ConnectionOptions (
734
+ host = "1.2.3.4" ,
735
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
736
+ ),
737
+ )
738
+
739
+ # Insecure channel and transport creation should be done twice - one for each different connection.
740
+ assert grpc_insecure_channel_mock .call_args_list == [
741
+ mock .call ("1.1.1.1:10002" ),
742
+ mock .call ("1.2.3.4:10002" ),
743
+ ]
744
+ assert transport_mock .call_args_list == [
745
+ mock .call (channel = grpc_chan1 ),
746
+ mock .call (channel = grpc_chan2 ),
747
+ ]
748
+
749
+
750
+ def test_ffv_optimized_psc_bad_gapic_client_raises_error (
751
+ get_psc_optimized_fos_mock , get_optimized_fv_mock , utils_client_with_override_mock
752
+ ):
753
+ with pytest .raises (ValueError ) as excinfo :
754
+ FeatureView (_TEST_OPTIMIZED_FV1_PATH ).read (
755
+ key = ["key1" ],
756
+ connection_options = fs_utils .ConnectionOptions (
757
+ host = "1.1.1.1" ,
758
+ transport = fs_utils .ConnectionOptions .InsecureGrpcChannel (),
759
+ ),
760
+ )
761
+
762
+ assert str (excinfo .value ) == (
763
+ f"Unexpected gapic class '{ utils_client_with_override_mock .get_gapic_client_class .return_value } ' used by internal client."
764
+ )
765
+
766
+
431
767
@pytest .mark .parametrize ("output_type" , ["dict" , "proto" ])
432
768
def test_search_nearest_entities (
433
769
get_esf_optimized_fos_mock ,
0 commit comments