#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABCMeta
from abc import abstractmethod
from typing import Any
from typing import List
from typing import Optional
from typing import Union
from pydantic import Field
from pydantic import StrictBytes
from pydantic import StrictStr
from typing_extensions import Annotated
from gs_interactive.api import AdminServiceGraphManagementApi
from gs_interactive.api import AdminServiceJobManagementApi
from gs_interactive.api import AdminServiceProcedureManagementApi
from gs_interactive.api import AdminServiceServiceManagementApi
from gs_interactive.api import GraphServiceEdgeManagementApi
from gs_interactive.api import GraphServiceVertexManagementApi
from gs_interactive.api import QueryServiceApi
from gs_interactive.api import UtilsApi
from gs_interactive.api_client import ApiClient
from gs_interactive.client.generated.results_pb2 import CollectiveResults
from gs_interactive.client.result import Result
from gs_interactive.client.status import Status
from gs_interactive.client.status import StatusCode
from gs_interactive.client.utils import InputFormat
from gs_interactive.client.utils import append_format_byte
from gs_interactive.configuration import Configuration
from gs_interactive.models import CreateGraphRequest
from gs_interactive.models import CreateGraphResponse
from gs_interactive.models import CreateProcedureRequest
from gs_interactive.models import CreateProcedureResponse
from gs_interactive.models import EdgeRequest
from gs_interactive.models import GetGraphResponse
from gs_interactive.models import GetGraphSchemaResponse
from gs_interactive.models import GetGraphStatisticsResponse
from gs_interactive.models import GetProcedureResponse
from gs_interactive.models import JobResponse
from gs_interactive.models import JobStatus
from gs_interactive.models import QueryRequest
from gs_interactive.models import SchemaMapping
from gs_interactive.models import ServiceStatus
from gs_interactive.models import StartServiceRequest
from gs_interactive.models import StopServiceRequest
from gs_interactive.models import UpdateProcedureRequest
from gs_interactive.models import UploadFileResponse
from gs_interactive.models import VertexData
from gs_interactive.models import VertexEdgeRequest
from gs_interactive.models import VertexRequest
class EdgeInterface(metaclass=ABCMeta):
    """Abstract contract for the edge-level operations of a session."""

    @abstractmethod
    def add_edge(
        self, graph_id: StrictStr, edge_request: List[EdgeRequest]
    ) -> Result[str]:
        """Add the edges described by ``edge_request`` to graph ``graph_id``."""
        raise NotImplementedError()

    @abstractmethod
    def delete_edge(
        self,
        graph_id: StrictStr,
        src_label: Annotated[
            StrictStr, Field(description="The label name of src vertex.")
        ],
        src_primary_key_value: Annotated[
            Any, Field(description="The primary key value of src vertex.")
        ],
        dst_label: Annotated[
            StrictStr, Field(description="The label name of dst vertex.")
        ],
        dst_primary_key_value: Annotated[
            Any, Field(description="The primary key value of dst vertex.")
        ],
    ) -> Result[str]:
        """Delete the edge between the given source and destination vertices."""
        raise NotImplementedError()

    @abstractmethod
    def get_edge(
        self,
        graph_id: StrictStr,
        edge_label: Annotated[StrictStr, Field(description="The label name of edge.")],
        src_label: Annotated[
            StrictStr, Field(description="The label name of src vertex.")
        ],
        src_primary_key_value: Annotated[
            Any, Field(description="The primary key value of src vertex.")
        ],
        dst_label: Annotated[
            StrictStr, Field(description="The label name of dst vertex.")
        ],
        dst_primary_key_value: Annotated[
            Any, Field(description="The primary key value of dst vertex.")
        ],
    ) -> Result[Union[None, EdgeRequest]]:
        """Fetch the edge with label ``edge_label`` between the given vertices."""
        raise NotImplementedError()

    @abstractmethod
    def update_edge(
        self, graph_id: StrictStr, edge_request: EdgeRequest
    ) -> Result[str]:
        """Update the edge described by ``edge_request`` in graph ``graph_id``."""
        raise NotImplementedError()
class VertexInterface(metaclass=ABCMeta):
    """Abstract contract for the vertex-level operations of a session."""

    @abstractmethod
    def add_vertex(
        self,
        graph_id: StrictStr,
        vertex_edge_request: VertexEdgeRequest,
    ) -> Result[StrictStr]:
        """Add the content of ``vertex_edge_request`` to graph ``graph_id``."""
        raise NotImplementedError()

    @abstractmethod
    def delete_vertex(
        self,
        graph_id: StrictStr,
        label: Annotated[StrictStr, Field(description="The label name of vertex.")],
        primary_key_value: Annotated[
            Any, Field(description="The primary key value of vertex.")
        ],
    ) -> Result[str]:
        """Delete the vertex identified by ``label`` and ``primary_key_value``."""
        raise NotImplementedError()

    @abstractmethod
    def get_vertex(
        self,
        graph_id: StrictStr,
        label: Annotated[StrictStr, Field(description="The label name of vertex.")],
        primary_key_value: Annotated[
            Any, Field(description="The primary key value of vertex.")
        ],
    ) -> Result[VertexData]:
        """Fetch the vertex identified by ``label`` and ``primary_key_value``."""
        raise NotImplementedError()

    @abstractmethod
    def update_vertex(
        self, graph_id: StrictStr, vertex_request: VertexRequest
    ) -> Result[str]:
        """Update the vertex described by ``vertex_request``."""
        raise NotImplementedError()
class GraphInterface(metaclass=ABCMeta):
    """Abstract contract for the graph-management operations of a session.

    Every method is an instance method; implementations receive the graph id
    as the first argument after ``self``.
    """

    @abstractmethod
    def create_graph(self, graph: CreateGraphRequest) -> Result[CreateGraphResponse]:
        """Create a graph from the given request."""
        raise NotImplementedError

    @abstractmethod
    def get_graph_schema(
        self,
        graph_id: Annotated[StrictStr, Field(description="The id of graph to get")],
    ) -> Result[GetGraphSchemaResponse]:
        """Return the schema of graph ``graph_id``."""
        raise NotImplementedError

    @abstractmethod
    def get_graph_meta(
        self,
        graph_id: Annotated[StrictStr, Field(description="The id of graph to get")],
    ) -> Result[GetGraphResponse]:
        """Return the metadata of graph ``graph_id``."""
        raise NotImplementedError

    @abstractmethod
    def get_graph_statistics(
        self,
        graph_id: Annotated[StrictStr, Field(description="The id of graph to get")],
    ) -> Result[GetGraphStatisticsResponse]:
        """Return the statistics of graph ``graph_id``."""
        raise NotImplementedError

    @abstractmethod
    def delete_graph(
        self,
        graph_id: Annotated[StrictStr, Field(description="The id of graph to delete")],
    ) -> Result[str]:
        """Delete graph ``graph_id``."""
        raise NotImplementedError

    @abstractmethod
    def list_graphs(self) -> Result[List[GetGraphResponse]]:
        """List all graphs."""
        raise NotImplementedError

    @abstractmethod
    def bulk_loading(
        self,
        graph_id: Annotated[StrictStr, Field(description="The id of graph to load")],
        schema_mapping: SchemaMapping,
    ) -> Result[JobResponse]:
        """Submit a bulk loading job for graph ``graph_id``."""
        raise NotImplementedError
class ProcedureInterface(metaclass=ABCMeta):
    """Abstract contract for managing and calling stored procedures."""

    @abstractmethod
    def create_procedure(
        self, graph_id: StrictStr, procedure: CreateProcedureRequest
    ) -> Result[CreateProcedureResponse]:
        """Create a procedure in graph ``graph_id``."""
        raise NotImplementedError()

    @abstractmethod
    def delete_procedure(
        self, graph_id: StrictStr, procedure_id: StrictStr
    ) -> Result[str]:
        """Delete procedure ``procedure_id`` from graph ``graph_id``."""
        raise NotImplementedError()

    @abstractmethod
    def list_procedures(
        self, graph_id: StrictStr
    ) -> Result[List[GetProcedureResponse]]:
        """List the procedures of graph ``graph_id``."""
        raise NotImplementedError()

    @abstractmethod
    def update_procedure(
        self, graph_id: StrictStr, proc_id: StrictStr, procedure: UpdateProcedureRequest
    ) -> Result[str]:
        """Update procedure ``proc_id`` of graph ``graph_id``."""
        raise NotImplementedError()

    @abstractmethod
    def get_procedure(
        self, graph_id: StrictStr, procedure_id: StrictStr
    ) -> Result[GetProcedureResponse]:
        """Fetch procedure ``procedure_id`` of graph ``graph_id``."""
        raise NotImplementedError()

    @abstractmethod
    def call_procedure(
        self, graph_id: StrictStr, params: QueryRequest
    ) -> Result[CollectiveResults]:
        """Call a procedure on graph ``graph_id`` with a structured request."""
        raise NotImplementedError()

    @abstractmethod
    def call_procedure_current(self, params: QueryRequest) -> Result[CollectiveResults]:
        """Call a procedure on the current graph with a structured request."""
        raise NotImplementedError()

    @abstractmethod
    def call_procedure_raw(self, graph_id: StrictStr, params: bytes) -> Result[str]:
        """Call a procedure on graph ``graph_id`` with a raw byte payload."""
        raise NotImplementedError()

    @abstractmethod
    def call_procedure_current_raw(self, params: bytes) -> Result[str]:
        """Call a procedure on the current graph with a raw byte payload."""
        raise NotImplementedError()
class QueryServiceInterface(metaclass=ABCMeta):
    """Abstract contract for controlling the query service.

    The class uses ``ABCMeta`` so that its ``@abstractmethod`` markers are
    actually enforced when it is subclassed on its own.
    """

    @abstractmethod
    def get_service_status(self) -> Result[ServiceStatus]:
        """Return the current status of the service."""
        raise NotImplementedError

    @abstractmethod
    def start_service(
        self,
        start_service_request: Annotated[
            Optional[StartServiceRequest],
            Field(description="Start service on a specified graph"),
        ] = None,
    ) -> Result[str]:
        """Start the service, optionally on the graph named in the request."""
        raise NotImplementedError

    @abstractmethod
    def stop_service(self, graph_id: Optional[str] = None) -> Result[str]:
        """Stop the service; ``graph_id`` is optional."""
        raise NotImplementedError

    @abstractmethod
    def restart_service(self) -> Result[str]:
        """Restart the service."""
        raise NotImplementedError
class JobInterface(metaclass=ABCMeta):
    """Abstract contract for inspecting and cancelling jobs."""

    @abstractmethod
    def get_job(self, job_id: StrictStr) -> Result[JobStatus]:
        """Return the status of job ``job_id``."""
        raise NotImplementedError()

    @abstractmethod
    def list_jobs(self) -> Result[List[JobResponse]]:
        """List all jobs."""
        raise NotImplementedError()

    @abstractmethod
    def cancel_job(self, job_id: StrictStr) -> Result[str]:
        """Cancel job ``job_id``."""
        raise NotImplementedError()
class UiltsInterface(metaclass=ABCMeta):
    """Abstract contract for utility operations (currently file upload).

    NOTE: the class name is a misspelling of "Utils"; it is kept unchanged
    because it is part of the public interface.
    """

    @abstractmethod
    def upload_file(
        self, filestorage: Optional[Union[StrictBytes, StrictStr]]
    ) -> Result[UploadFileResponse]:
        """Upload ``filestorage`` to the server."""
        raise NotImplementedError()
class Session(
    VertexInterface,
    EdgeInterface,
    GraphInterface,
    ProcedureInterface,
    JobInterface,
    QueryServiceInterface,
    UiltsInterface,
):
    """Aggregate interface combining every service interface of the SDK."""
[docs]class DefaultSession(Session):
"""
The default session implementation for Interactive SDK.
It provides the implementation of all service APIs.
"""
def __init__(self, admin_uri: str, stored_proc_uri: Optional[str] = None):
    """
    Construct a new session using the specified admin_uri and stored_proc_uri.

    Args:
        admin_uri (str): the uri for the admin service.
        stored_proc_uri (str, optional): the uri for the stored procedure
            service. If not provided, it is derived from ``admin_uri`` by
            replacing everything after the last ":" with the ``hqps_port``
            reported by the service status, so ``admin_uri`` is expected to
            end with ":<port>" in that case.

    Raises:
        Exception: if ``stored_proc_uri`` is not provided and the service
            status cannot be retrieved.
    """
    self._admin_uri = admin_uri
    self._client = ApiClient(Configuration(host=admin_uri))

    self._graph_api = AdminServiceGraphManagementApi(self._client)
    self._job_api = AdminServiceJobManagementApi(self._client)
    self._procedure_api = AdminServiceProcedureManagementApi(self._client)
    self._service_api = AdminServiceServiceManagementApi(self._client)
    self._utils_api = UtilsApi(self._client)

    if stored_proc_uri is None:
        service_status = self.get_service_status()
        if not service_status.is_ok():
            # Build one message string; passing two positional args to
            # Exception would render as a tuple when printed.
            raise Exception(
                f"Failed to get service status: "
                f"{service_status.get_status_message()}"
            )
        service_port = service_status.get_value().hqps_port
        # replace the port in uri
        splitted = admin_uri.split(":")
        splitted[-1] = str(service_port)
        stored_proc_uri = ":".join(splitted)

    self._query_client = ApiClient(Configuration(host=stored_proc_uri))
    self._query_api = QueryServiceApi(self._query_client)
    self._edge_api = GraphServiceEdgeManagementApi(self._query_client)
    self._vertex_api = GraphServiceVertexManagementApi(self._query_client)
def __enter__(self):
    """Enter both underlying API clients and return this session."""
    self._client.__enter__()
    self._query_client.__enter__()
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Exit both underlying API clients.

    The query client is exited even if exiting the admin client raises.
    Returns None, so an in-flight exception is never suppressed.
    """
    try:
        self._client.__exit__(exc_type=exc_type, exc_value=exc_val, traceback=exc_tb)
    finally:
        self._query_client.__exit__(
            exc_type=exc_type, exc_value=exc_val, traceback=exc_tb
        )
def admin_uri(self):
    """Return the admin service uri this session was constructed with."""
    return self._admin_uri
# implementations of the methods from the interfaces
def add_vertex(
    self,
    graph_id: StrictStr,
    vertex_edge_request: VertexEdgeRequest,
) -> Result[StrictStr]:
    """
    Add a vertex to the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        vertex_edge_request: request body forwarded to the vertex API.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._vertex_api.add_vertex_with_http_info(graph_id, vertex_edge_request)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def delete_vertex(
    self,
    graph_id: StrictStr,
    label: Annotated[StrictStr, Field(description="The label name of vertex.")],
    primary_key_value: Annotated[
        Any, Field(description="The primary key value of vertex.")
    ],
) -> Result[str]:
    """
    Delete a vertex from the specified graph.

    Not implemented by this session: the call always raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def get_vertex(
    self,
    graph_id: StrictStr,
    label: Annotated[StrictStr, Field(description="The label name of vertex.")],
    primary_key_value: Annotated[
        Any, Field(description="The primary key value of vertex.")
    ],
) -> Result[VertexData]:
    """
    Get a vertex from the specified graph with primary key value.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        label: label name of the vertex.
        primary_key_value: primary key value of the vertex.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._vertex_api.get_vertex_with_http_info(
                graph_id, label, primary_key_value
            )
        )
    except Exception as exc:
        return Result.from_exception(exc)
def update_vertex(
    self, graph_id: StrictStr, vertex_request: VertexRequest
) -> Result[str]:
    """
    Update a vertex in the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        vertex_request: request body forwarded to the vertex API.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._vertex_api.update_vertex_with_http_info(graph_id, vertex_request)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def add_edge(
    self, graph_id: StrictStr, edge_request: List[EdgeRequest]
) -> Result[str]:
    """
    Add an edge to the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        edge_request: list of edge requests forwarded to the edge API.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._edge_api.add_edge_with_http_info(graph_id, edge_request)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def delete_edge(
    self,
    graph_id: StrictStr,
    src_label: Annotated[
        StrictStr, Field(description="The label name of src vertex.")
    ],
    src_primary_key_value: Annotated[
        Any, Field(description="The primary key value of src vertex.")
    ],
    dst_label: Annotated[
        StrictStr, Field(description="The label name of dst vertex.")
    ],
    dst_primary_key_value: Annotated[
        Any, Field(description="The primary key value of dst vertex.")
    ],
) -> Result[str]:
    """
    Delete an edge from the specified graph.

    Not implemented by this session: the call always raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def get_edge(
    self,
    graph_id: StrictStr,
    edge_label: Annotated[StrictStr, Field(description="The label name of edge.")],
    src_label: Annotated[
        StrictStr, Field(description="The label name of src vertex.")
    ],
    src_primary_key_value: Annotated[
        Any, Field(description="The primary key value of src vertex.")
    ],
    dst_label: Annotated[
        StrictStr, Field(description="The label name of dst vertex.")
    ],
    dst_primary_key_value: Annotated[
        Any, Field(description="The primary key value of dst vertex.")
    ],
) -> Result[Union[None, EdgeRequest]]:
    """
    Get an edge from the specified graph with primary key value.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        edge_label: label name of the edge.
        src_label: label name of the source vertex.
        src_primary_key_value: primary key value of the source vertex.
        dst_label: label name of the destination vertex.
        dst_primary_key_value: primary key value of the destination vertex.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._edge_api.get_edge_with_http_info(
                graph_id,
                edge_label,
                src_label,
                src_primary_key_value,
                dst_label,
                dst_primary_key_value,
            )
        )
    except Exception as exc:
        return Result.from_exception(exc)
def update_edge(
    self, graph_id: StrictStr, edge_request: EdgeRequest
) -> Result[str]:
    """
    Update an edge in the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        edge_request: request body forwarded to the edge API.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._edge_api.update_edge_with_http_info(graph_id, edge_request)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def create_graph(self, graph: CreateGraphRequest) -> Result[CreateGraphResponse]:
    """
    Create a new graph with the specified graph request.

    Args:
        graph: request body forwarded to the graph management API.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.
    """
    try:
        return Result.from_response(
            self._graph_api.create_graph_with_http_info(graph)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def get_graph_schema(
    self,
    graph_id: Annotated[StrictStr, Field(description="The id of graph to get")],
) -> Result[GetGraphSchemaResponse]:
    """
    Get the schema of a specified graph.

    Args:
        graph_id: id of the graph whose schema is requested; an int is
            converted to str.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._graph_api.get_schema_with_http_info(graph_id)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def get_graph_statistics(
    self,
    graph_id: Annotated[StrictStr, Field(description="The id of graph to get")],
) -> Result[GetGraphStatisticsResponse]:
    """
    Get the statistics of a specified graph.

    Args:
        graph_id: id of the graph; an int is converted to str.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._graph_api.get_graph_statistic_with_http_info(graph_id)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def delete_graph(
    self,
    graph_id: Annotated[StrictStr, Field(description="The id of graph to delete")],
) -> Result[str]:
    """
    Delete a graph with the specified graph id.

    Args:
        graph_id: id of the graph to delete; an int is converted to str.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._graph_api.delete_graph_with_http_info(graph_id)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def list_graphs(self) -> Result[List[GetGraphResponse]]:
    """
    List all graphs.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.
    """
    try:
        return Result.from_response(self._graph_api.list_graphs_with_http_info())
    except Exception as exc:
        return Result.from_exception(exc)
def bulk_loading(
    self,
    graph_id: Annotated[StrictStr, Field(description="The id of graph to load")],
    schema_mapping: SchemaMapping,
) -> Result[JobResponse]:
    """
    Submit a bulk loading job to the specified graph.

    Input files marked with a leading "@" are first handled by
    ``try_upload_files``; the schema mapping it returns is the one submitted.

    Args:
        graph_id: id of the graph to load; an int is converted to str.
        schema_mapping: mapping that describes the data to load.

    Returns:
        The failed Result from ``try_upload_files`` if that step fails;
        otherwise a Result built from the HTTP response of the job
        submission, or from the exception if the submission fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    # Upload step comes first so the job sees the server-side file paths.
    uploaded = self.try_upload_files(schema_mapping)
    if not uploaded.is_ok():
        return uploaded
    schema_mapping = uploaded.get_value()
    print("new schema_mapping: ", schema_mapping)
    try:
        return Result.from_response(
            self._graph_api.create_dataloading_job_with_http_info(
                graph_id, schema_mapping
            )
        )
    except Exception as exc:
        return Result.from_exception(exc)
def create_procedure(
    self, graph_id: StrictStr, procedure: CreateProcedureRequest
) -> Result[CreateProcedureResponse]:
    """
    Create a new procedure in the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        procedure: request body forwarded to the procedure API.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._procedure_api.create_procedure_with_http_info(graph_id, procedure)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def delete_procedure(
    self, graph_id: StrictStr, procedure_id: StrictStr
) -> Result[str]:
    """
    Delete a procedure in the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        procedure_id: id of the procedure; an int is converted to str.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` or ``procedure_id`` is neither a str nor
            an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    procedure_id = self.ensure_param_str("procedure_id", procedure_id)
    try:
        return Result.from_response(
            self._procedure_api.delete_procedure_with_http_info(
                graph_id, procedure_id
            )
        )
    except Exception as exc:
        return Result.from_exception(exc)
def list_procedures(
    self, graph_id: StrictStr
) -> Result[List[GetProcedureResponse]]:
    """
    List all procedures in the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        return Result.from_response(
            self._procedure_api.list_procedures_with_http_info(graph_id)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def update_procedure(
    self, graph_id: StrictStr, proc_id: StrictStr, procedure: UpdateProcedureRequest
) -> Result[str]:
    """
    Update a procedure in the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        proc_id: id of the procedure to update; an int is converted to str,
            the same normalisation ``delete_procedure`` applies to its
            procedure id.
        procedure: request body forwarded to the procedure API.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` or ``proc_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    proc_id = self.ensure_param_str("proc_id", proc_id)
    try:
        response = self._procedure_api.update_procedure_with_http_info(
            graph_id, proc_id, procedure
        )
        return Result.from_response(response)
    except Exception as e:
        return Result.from_exception(e)
def get_procedure(
    self, graph_id: StrictStr, procedure_id: StrictStr
) -> Result[GetProcedureResponse]:
    """
    Get a procedure in the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        procedure_id: id of the procedure; an int is converted to str, the
            same normalisation ``delete_procedure`` applies.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` or ``procedure_id`` is neither a str nor
            an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    procedure_id = self.ensure_param_str("procedure_id", procedure_id)
    try:
        response = self._procedure_api.get_procedure_with_http_info(
            graph_id, procedure_id
        )
        return Result.from_response(response)
    except Exception as e:
        return Result.from_exception(e)
def call_procedure(
    self, graph_id: StrictStr, params: QueryRequest
) -> Result[CollectiveResults]:
    """
    Call a procedure in the specified graph.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        params: the query request; it is serialised to JSON and sent as the
            request body.

    Returns:
        On HTTP 200, an ok Result holding the ``CollectiveResults`` parsed
        from the response data. On any other status, a Result carrying the
        status derived from the response and an empty ``CollectiveResults``.
        If anything raises, a Result built from the exception.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        # The server supports several input formats (see
        # flex/engines/graph_db/graph_db_session.h); the payload is tagged
        # here as InputFormat.CYPHER_JSON.
        payload = append_format_byte(
            params.to_json().encode(), InputFormat.CYPHER_JSON
        )
        response = self._query_api.call_proc_with_http_info(
            graph_id=graph_id,
            body=payload,
        )
        collected = CollectiveResults()
        if response.status_code != 200:
            return Result(Status.from_response(response), collected)
        collected.ParseFromString(response.data)
        return Result.ok(collected)
    except Exception as exc:
        return Result.from_exception(exc)
def call_procedure_current(self, params: QueryRequest) -> Result[CollectiveResults]:
    """
    Call a procedure in the current graph.

    Args:
        params: the query request; it is serialised to JSON and sent as the
            request body.

    Returns:
        On HTTP 200, an ok Result holding the ``CollectiveResults`` parsed
        from the response data. On any other status, a Result carrying the
        status derived from the response and an empty ``CollectiveResults``.
        If anything raises, a Result built from the exception.
    """
    try:
        # The server supports several input formats (see
        # flex/engines/graph_db/graph_db_session.h); the payload is tagged
        # here as InputFormat.CYPHER_JSON.
        payload = append_format_byte(
            params.to_json().encode(), InputFormat.CYPHER_JSON
        )
        response = self._query_api.call_proc_current_with_http_info(body=payload)
        collected = CollectiveResults()
        if response.status_code != 200:
            return Result(Status.from_response(response), collected)
        collected.ParseFromString(response.data)
        return Result.ok(collected)
    except Exception as exc:
        return Result.from_exception(exc)
def call_procedure_raw(self, graph_id: StrictStr, params: bytes) -> Result[str]:
    """
    Call a procedure in the specified graph with raw bytes.

    Args:
        graph_id: id of the target graph; an int is converted to str.
        params: raw request payload, sent after being tagged as
            ``InputFormat.CPP_ENCODER``.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``graph_id`` is neither a str nor an int.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        # The server supports several input formats (see
        # flex/engines/graph_db/graph_db_session.h); raw payloads are tagged
        # as the encoder/decoder format.
        payload = append_format_byte(params, InputFormat.CPP_ENCODER)
        return Result.from_response(
            self._query_api.call_proc_with_http_info(
                graph_id=graph_id,
                body=payload,
            )
        )
    except Exception as exc:
        return Result.from_exception(exc)
def call_procedure_current_raw(self, params: bytes) -> Result[str]:
    """
    Call a procedure in the current graph with raw bytes.

    Args:
        params: raw request payload, sent after being tagged as
            ``InputFormat.CPP_ENCODER``.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.
    """
    try:
        # The server supports several input formats (see
        # flex/engines/graph_db/graph_db_session.h); raw payloads are tagged
        # as the encoder/decoder format.
        payload = append_format_byte(params, InputFormat.CPP_ENCODER)
        return Result.from_response(
            self._query_api.call_proc_current_with_http_info(body=payload)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def get_service_status(self) -> Result[ServiceStatus]:
    """
    Get the status of the service.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.
    """
    try:
        return Result.from_response(
            self._service_api.get_service_status_with_http_info()
        )
    except Exception as exc:
        return Result.from_exception(exc)
def start_service(
    self,
    start_service_request: Annotated[
        Optional[StartServiceRequest],
        Field(description="Start service on a specified graph"),
    ] = None,
) -> Result[str]:
    """
    Start the service on a specified graph.

    Args:
        start_service_request: optional request forwarded unchanged to the
            service API.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.
    """
    try:
        return Result.from_response(
            self._service_api.start_service_with_http_info(start_service_request)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def stop_service(self, graph_id: str = None) -> Result[str]:
    """
    Stop the service.

    Args:
        graph_id: optional graph id; when truthy it is set on the stop
            request, otherwise the request is sent without it.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.
    """
    try:
        stop_request = StopServiceRequest()
        if graph_id:
            stop_request.graph_id = graph_id
        return Result.from_response(
            self._service_api.stop_service_with_http_info(stop_request)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def restart_service(self) -> Result[str]:
    """
    Restart the service.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.
    """
    try:
        return Result.from_response(
            self._service_api.restart_service_with_http_info()
        )
    except Exception as exc:
        return Result.from_exception(exc)
def get_job(self, job_id: StrictStr) -> Result[JobStatus]:
    """
    Get the status of a job with the specified job id.

    Args:
        job_id: id of the job; an int is converted to str.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``job_id`` is neither a str nor an int.
    """
    job_id = self.ensure_param_str("job_id", job_id)
    try:
        return Result.from_response(
            self._job_api.get_job_by_id_with_http_info(job_id)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def list_jobs(self) -> Result[List[JobResponse]]:
    """
    List all jobs.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.
    """
    try:
        return Result.from_response(self._job_api.list_jobs_with_http_info())
    except Exception as exc:
        return Result.from_exception(exc)
def cancel_job(self, job_id: StrictStr) -> Result[str]:
    """
    Cancel a job with the specified job id.

    Args:
        job_id: id of the job; an int is converted to str.

    Returns:
        A Result built from the HTTP response, or from the exception if the
        request fails.

    Raises:
        Exception: if ``job_id`` is neither a str nor an int.
    """
    job_id = self.ensure_param_str("job_id", job_id)
    try:
        return Result.from_response(
            self._job_api.delete_job_by_id_with_http_info(job_id)
        )
    except Exception as exc:
        return Result.from_exception(exc)
def upload_file(
    self, filestorage: Optional[Union[StrictBytes, StrictStr]]
) -> Result[UploadFileResponse]:
    """
    Upload a file to the server.

    Args:
        filestorage: the file to upload, forwarded unchanged to the utils API.

    Returns:
        A Result built from the HTTP response (on success the response
        describes the path of the uploaded file on the server), or from the
        exception if the request fails.
    """
    try:
        print("uploading file: ", filestorage)
        response = self._utils_api.upload_file_with_http_info(filestorage)
        print("response: ", response)
        if response.status_code != 200:
            # Report the file that failed (previously this printed the
            # builtin ``input`` function instead of the argument).
            print("Failed to upload file: ", filestorage)
        # Success and failure responses are both converted the same way.
        return Result.from_response(response)
    except Exception as e:
        print("got exception: ", e)
        return Result.from_exception(e)
def trim_path(self, path: str) -> str:
    """Return ``path`` without its leading "@", if it has one.

    Only a single leading "@" is removed; other paths are returned as-is.
    """
    if path.startswith("@"):
        return path[1:]
    return path
def preprocess_inputs(
    self, location: str, inputs: List[str], schema_mapping: SchemaMapping
):
    """
    Resolve input paths against the root ``location``.

    Args:
        location: root location of the data source; may be empty/None and
            may start with "@".
        inputs: input file paths of one vertex or edge mapping.
        schema_mapping: not used by this method.

    Returns:
        An ok Result holding the new input list. With a location, each entry
        becomes ``location + "/" + input`` (a leading "@" on the input is
        dropped); without one, inputs are kept unchanged. If the location has
        no "@" but an input starts with "@", a BAD_REQUEST error Result is
        returned carrying the inputs resolved so far.
    """
    at_marked_root = bool(location) and location.startswith("@")
    resolved = []
    for index, raw_path in enumerate(inputs):
        # An "@"-marked input under an unmarked root is ambiguous: reject it.
        if location and not at_marked_root and raw_path.startswith("@"):
            message = (
                "Root location given without @, but the input file starts with @"
                + raw_path
                + ", index: "
                + str(index)
            )
            print(message)
            return Result.error(
                Status(StatusCode.BAD_REQUEST, message),
                resolved,
            )
        if location:
            resolved.append(location + "/" + self.trim_path(raw_path))
        else:
            resolved.append(raw_path)
    return Result.ok(resolved)
def check_file_mixup(self, schema_mapping: SchemaMapping) -> Result[SchemaMapping]:
    """
    Resolve all mapping inputs and reject a mix of upload and non-upload files.

    The inputs of every vertex mapping, then every edge mapping, are rewritten
    in place with the result of ``preprocess_inputs``. A resolved path that
    starts with "@" counts as a file to upload.

    Args:
        schema_mapping: the schema mapping to check; modified in place.

    Returns:
        An ok Result holding ``schema_mapping`` when the data source scheme is
        not "file" (no check is done), or when either none or all of the
        resolved inputs start with "@". Otherwise an error Result: the status
        from ``preprocess_inputs`` if resolving failed, or BAD_REQUEST when
        upload and non-upload files are mixed.
    """
    location = None
    loading_config = schema_mapping.loading_config
    if loading_config and loading_config.data_source:
        if loading_config.data_source.scheme != "file":
            print("Only check mixup for file scheme")
            return Result.ok(schema_mapping)
        location = loading_config.data_source.location

    extracted_files = []
    # Vertex mappings first, then edge mappings, as both are handled the same.
    for mappings in (schema_mapping.vertex_mappings, schema_mapping.edge_mappings):
        for mapping in mappings or []:
            if not mapping.inputs:
                continue
            preprocess_result = self.preprocess_inputs(
                location, mapping.inputs, schema_mapping
            )
            if not preprocess_result.is_ok():
                return Result.error(preprocess_result.status, schema_mapping)
            mapping.inputs = preprocess_result.get_value()
            extracted_files.extend(mapping.inputs)

    if extracted_files:
        # count the number of files start with @
        upload_count = sum(1 for path in extracted_files if path.startswith("@"))
        if upload_count == 0:
            print("No file to upload")
        elif upload_count != len(extracted_files):
            message = "Can not mix uploading file and not uploading file"
            print(message)
            # Result.error takes a Status and a value, as at the other call
            # sites; a bare string was passed here before.
            return Result.error(
                Status(StatusCode.BAD_REQUEST, message), schema_mapping
            )
    return Result.ok(schema_mapping)
def try_upload_files(self, schema_mapping: SchemaMapping) -> Result[SchemaMapping]:
    """
    Try to upload the input files that are marked with a starting @.

    The @ can be added to the beginning of data_source.location in
    schema_mapping.loading_config, or to each file in vertex_mappings and
    edge_mappings. ``check_file_mixup`` validates the combination:

    1. location: @/path/to/dir, inputs: @/path/to/file  -> marked for upload
    2. location: /path/to/dir,  inputs: @/path/to/file  -> rejected (BAD_REQUEST)
    3. location: @/path/to/dir, inputs: /path/to/file   -> marked for upload
    4. location: /path/to/dir,  inputs: /path/to/file   -> nothing to upload
    5. location: None,          inputs: @/path/to/file  -> marked for upload

    After validation the mapping is handed to
    ``upload_and_replace_input_inplace``, which presumably uploads the marked
    files and substitutes the server-side paths (not visible here - confirm).

    Args:
        schema_mapping: the schema mapping whose inputs may need uploading.

    Returns:
        The failed Result of the validation or upload step if either fails,
        otherwise an ok Result holding the value returned by the upload step.
    """
    checked = self.check_file_mixup(schema_mapping)
    if not checked.is_ok():
        return checked
    schema_mapping = checked.get_value()

    # now try upload the replace inplace
    print("after check_mixup_res: ")
    uploaded = self.upload_and_replace_input_inplace(schema_mapping)
    if not uploaded.is_ok():
        return uploaded
    print("new schema_mapping: ", uploaded.get_value())
    return Result.ok(uploaded.get_value())
def ensure_param_str(self, param_name: str, param):
    """
    Ensure the param is a string, otherwise raise an exception.

    Args:
        param_name: name of the parameter, used only in the error message.
        param: the value to check.

    Returns:
        ``param`` itself if it is a str, or ``str(param)`` if it is an int
        (users may pass an id such as graph_id as an int).

    Raises:
        Exception: if ``param`` is neither a str nor an int.
    """
    if isinstance(param, str):
        return param
    # User may input the graph_id as int, convert it to string
    if isinstance(param, int):
        return str(param)
    raise Exception(
        "param should be a string, param_name: "
        + param_name
        + ", param: "
        + str(param)
    )