PySpark MLでカスタムトランスフォーマーを作成する のコメントセクションで同じ議論を見つけましたが、明確な答えはありません。これに対応する未解決のJIRAもあります: https://issues.Apache.org/jira/browse/SPARK-17025 。
Pythonで記述されたカスタムトランスフォーマーを保存するためのPysparkMLパイプラインによって提供されるオプションがない場合、それを実行するための他のオプションは何ですか?互換性のあるpythonオブジェクトを返す)Javaクラスに_to_Javaメソッドを実装するにはどうすればよいですか?
Spark 2.3.0の時点で、much、muchこれを行うためのより良い方法。
DefaultParamsWritable
および DefaultParamsReadable
を拡張するだけで、クラスには自動的にwrite
およびread
メソッドが含まれます。パラメータを保存し、PipelineModel
シリアル化システムによって使用されます。
ドキュメントはあまり明確ではなく、これが逆シリアル化が機能する方法であることを理解するために、ソースを少し読む必要がありました。
PipelineModel.read
_PipelineModelReader
をインスタンス化しますPipelineModelReader
はメタデータをロードし、言語が_'Python'
_であるかどうかを確認します。そうでない場合は、通常のJavaMLReader
が使用されます(これらの回答のほとんどは何のために設計されていますか)PipelineSharedReadWrite
が使用され、 _DefaultParamsReader.loadParamsInstance
_ が呼び出されます。loadParamsInstance
は、保存されたメタデータからclass
を検索します。そのクラスをインスタンス化し、そのクラスで.load(path)
を呼び出します。 DefaultParamsReader
を拡張して、 _DefaultParamsReader.load
_ メソッドを自動的に取得できます。実装する必要のある特殊な逆シリアル化ロジックがある場合は、そのload
メソッドを出発点と見なします。
反対側:
PipelineModel.write
_ すべてのステージがJava(implement JavaMLWritable
)であるかどうかを確認します。そうである場合、通常のJavaMLWriter
は使用済み(これらの回答のほとんどは何のために設計されているか)PipelineWriter
が使用され、すべてのステージがMLWritable
を実装していることを確認し、_PipelineSharedReadWrite.saveImpl
_を呼び出します。PipelineSharedReadWrite.saveImpl
_は、各ステージで.write().save(path)
を呼び出します。DefaultParamsWriter
を拡張して、クラスとパラメーターのメタデータを正しい形式で保存する _DefaultParamsWritable.write
_ メソッドを取得できます。実装する必要のあるカスタムシリアル化ロジックがある場合は、それと DefaultParamsWriter
を出発点として検討します。
さて、最後に、Paramsを拡張する非常に単純なトランスフォーマーがあり、すべてのパラメーターは通常のParams形式で格納されます。
_from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasOutputCols, Param, Params
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit # for the dummy _transform
class SetValueTransformer(
Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
):
value = Param(
Params._dummy(),
"value",
"value to fill",
)
@keyword_only
def __init__(self, outputCols=None, value=0.0):
super(SetValueTransformer, self).__init__()
self._setDefault(value=0.0)
kwargs = self._input_kwargs
self._set(**kwargs)
@keyword_only
def setParams(self, outputCols=None, value=0.0):
"""
setParams(self, outputCols=None, value=0.0)
Sets params for this SetValueTransformer.
"""
kwargs = self._input_kwargs
return self._set(**kwargs)
def setValue(self, value):
"""
Sets the value of :py:attr:`value`.
"""
return self._set(value=value)
def getValue(self):
"""
Gets the value of :py:attr:`value` or its default value.
"""
return self.getOrDefault(self.value)
def _transform(self, dataset):
for col in self.getOutputCols():
dataset = dataset.withColumn(col, lit(self.getValue()))
return dataset
_
これで使用できます。
_from pyspark.ml import Pipeline
svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)
p = Pipeline(stages=[svt])
df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
pm = p.fit(df)
pm.transform(df).show()
pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
pm2.transform(df).show()
_
結果:
_+---+-----+-----+-----+
|key|value| a| b|
+---+-----+-----+-----+
| 1| null|123.0|123.0|
| 2| 1.0|123.0|123.0|
| 3| 0.5|123.0|123.0|
+---+-----+-----+-----+
matches? True
+---+-----+-----+-----+
|key|value| a| b|
+---+-----+-----+-----+
| 1| null|123.0|123.0|
| 2| 1.0|123.0|123.0|
| 3| 0.5|123.0|123.0|
+---+-----+-----+-----+
_
これが最善のアプローチかどうかはわかりませんが、Pysparkで作成したカスタムのEstimator、Transformer、Modelを保存し、PipelineAPIでの使用を永続的にサポートする機能も必要です。カスタムPysparkEstimator、Transformers、およびModelsは、Pipeline APIで作成および使用できますが、保存することはできません。これは、モデルのトレーニングにイベント予測サイクルよりも時間がかかる場合に、本番環境で問題を引き起こします。
一般に、Pyspark Estimator、Transformers、Modelsは、JavaまたはScalaに相当するものの単なるラッパーであり、Pysparkラッパーはpy4jを介してJavaとの間でパラメーターをマーシャリングするだけです。その後、モデルの永続化はJava側で実行されます。この現在の構造により、カスタムPyspark Estimator、Transformers、Modelsはpythonの世界にのみ存在するように制限されます。
以前の試みでは、Pickle/dillシリアル化を使用して単一のPysparkモデルを保存することができました。これはうまく機能しましたが、PipelineAPI内からそのようなものを保存またはロードバックすることはできませんでした。しかし、別のSOの投稿で指摘されたように、私はOneVsRest分類子に誘導され、_to_Javaメソッドと_from_Javaメソッドを調べました。彼らはPyspark側ですべての重労働を行います。私が考えた後、ピクルスダンプをすでに作成されサポートされている保存可能なJavaオブジェクトに保存する方法があれば、PipelineAPIを使用してカスタムPysparkEstimator、Transformer、およびModelを保存できるはずです。
そのために、StopWordsRemoverは、文字列のリストである属性stopwordsを持っているため、ハイジャックの理想的なオブジェクトであることがわかりました。 dill.dumpsメソッドは、オブジェクトのピクルス表現を文字列として返します。計画は、文字列をリストに変換してから、StopWordsRemoverのstopwordsパラメーターをこのリストに設定することでした。リスト文字列ですが、一部の文字がJavaオブジェクトにマーシャリングされないことがわかりました。したがって、文字は整数に変換され、次に整数は文字列に変換されます。パイプラインは私のpythonクラスの_to_Javaメソッドを忠実に呼び出すため、これはすべて単一のインスタンスを保存する場合や、パイプライン内に保存する場合に最適です(これはまだPyspark側にあるため機能します)。しかし、JavaからPysparkに戻ることは、PipelineAPIにはありませんでした。
pythonオブジェクトをStopWordsRemoverインスタンスで非表示にしているため、Pysparkに戻ったとき、パイプラインは非表示のクラスオブジェクトについて何も認識せず、StopWordsRemoverインスタンスがあることのみを認識します。理想的には、PipelineとPipelineModelをサブクラス化するのは素晴らしいことですが、残念ながら、これによりPythonオブジェクトのシリアル化を試みることに戻ります。これに対抗するために、PipelineまたはPipelineModelを受け取り、ステージをスキャンして、ストップワードリストでコード化されたIDを探します(これはpythonオブジェクトのピクルスバイトにすぎないことを思い出してください)。リストをインスタンスにアンラップし、元のステージに保存します。以下は、これがすべてどのように機能するかを示すコードです。
カスタムPysparkEstimator、Transformer、およびModelの場合は、Identizable、PysparkReaderWriter、MLReadable、MLWritableから継承するだけです。次に、PipelineとPipelineModelをロードするときに、PysparkPipelineWrapper.unwrap(pipeline)を通過させます。
このメソッドは、JavaまたはScalaでのPysparkコードの使用には対応していませんが、少なくともカスタムPyspark Estimator、Transformers、Modelsを保存およびロードし、PipelineAPIを操作できます。
import dill
from pyspark.ml import Transformer, Pipeline, PipelineModel
from pyspark.ml.param import Param, Params
from pyspark.ml.util import Identifiable, MLReadable, MLWritable, JavaMLReader, JavaMLWriter
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.wrapper import JavaParams
from pyspark.context import SparkContext
from pyspark.sql import Row
class PysparkObjId(object):
"""
A class to specify constants used to idenify and setup python
Estimators, Transformers and Models so they can be serialized on there
own and from within a Pipline or PipelineModel.
"""
def __init__(self):
super(PysparkObjId, self).__init__()
@staticmethod
def _getPyObjId():
return '4c1740b00d3c4ff6806a1402321572cb'
@staticmethod
def _getCarrierClass(javaName=False):
return 'org.Apache.spark.ml.feature.StopWordsRemover' if javaName else StopWordsRemover
class PysparkPipelineWrapper(object):
"""
A class to facilitate converting the stages of a Pipeline or PipelineModel
that were saved from PysparkReaderWriter.
"""
def __init__(self):
super(PysparkPipelineWrapper, self).__init__()
@staticmethod
def unwrap(pipeline):
if not (isinstance(pipeline, Pipeline) or isinstance(pipeline, PipelineModel)):
raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))
stages = pipeline.getStages() if isinstance(pipeline, Pipeline) else pipeline.stages
for i, stage in enumerate(stages):
if (isinstance(stage, Pipeline) or isinstance(stage, PipelineModel)):
stages[i] = PysparkPipelineWrapper.unwrap(stage)
if isinstance(stage, PysparkObjId._getCarrierClass()) and stage.getStopWords()[-1] == PysparkObjId._getPyObjId():
swords = stage.getStopWords()[:-1] # strip the id
lst = [chr(int(d)) for d in swords]
dmp = ''.join(lst)
py_obj = dill.loads(dmp)
stages[i] = py_obj
if isinstance(pipeline, Pipeline):
pipeline.setStages(stages)
else:
pipeline.stages = stages
return pipeline
class PysparkReaderWriter(object):
"""
A mixin class so custom pyspark Estimators, Transformers and Models may
support saving and loading directly or be saved within a Pipline or PipelineModel.
"""
def __init__(self):
super(PysparkReaderWriter, self).__init__()
def write(self):
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
@classmethod
def read(cls):
"""Returns an MLReader instance for our clarrier class."""
return JavaMLReader(PysparkObjId._getCarrierClass())
@classmethod
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
swr_Java_obj = cls.read().load(path)
return cls._from_Java(swr_Java_obj)
@classmethod
def _from_Java(cls, Java_obj):
"""
Get the dumby the stopwords that are the characters of the dills dump plus our guid
and convert, via dill, back to our python instance.
"""
swords = Java_obj.getStopWords()[:-1] # strip the id
lst = [chr(int(d)) for d in swords] # convert from string integer list to bytes
dmp = ''.join(lst)
py_obj = dill.loads(dmp)
return py_obj
def _to_Java(self):
"""
Convert this instance to a dill dump, then to a list of strings with the unicode integer values of each character.
Use this list as a set of dumby stopwords and store in a StopWordsRemover instance
:return: Java object equivalent to this instance.
"""
dmp = dill.dumps(self)
pylist = [str(ord(d)) for d in dmp] # convert byes to string integer list
pylist.append(PysparkObjId._getPyObjId()) # add our id so PysparkPipelineWrapper can id us.
sc = SparkContext._active_spark_context
Java_class = sc._gateway.jvm.Java.lang.String
Java_array = sc._gateway.new_array(Java_class, len(pylist))
for i in xrange(len(pylist)):
Java_array[i] = pylist[i]
_Java_obj = JavaParams._new_Java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
_Java_obj.setStopWords(Java_array)
return _Java_obj
class HasFake(Params):
def __init__(self):
super(HasFake, self).__init__()
self.fake = Param(self, "fake", "fake param")
def getFake(self):
return self.getOrDefault(self.fake)
class MockTransformer(Transformer, HasFake, Identifiable):
def __init__(self):
super(MockTransformer, self).__init__()
self.dataset_count = 0
def _transform(self, dataset):
self.dataset_count = dataset.count()
return dataset
class MyTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
def __init__(self):
super(MyTransformer, self).__init__()
def make_a_dataframe(sc):
df = sc.parallelize([Row(name='Alice', age=5, height=80), Row(name='Alice', age=5, height=80), Row(name='Alice', age=10, height=80)]).toDF()
return df
def test1():
trA = MyTransformer()
trA.dataset_count = 999
print trA.dataset_count
trA.save('test.trans')
trB = MyTransformer.load('test.trans')
print trB.dataset_count
def test2():
trA = MyTransformer()
pipeA = Pipeline(stages=[trA])
print type(pipeA)
pipeA.save('testA.pipe')
pipeAA = PysparkPipelineWrapper.unwrap(Pipeline.load('testA.pipe'))
stagesAA = pipeAA.getStages()
trAA = stagesAA[0]
print trAA.dataset_count
def test3():
dfA = make_a_dataframe(sc)
trA = MyTransformer()
pipeA = Pipeline(stages=[trA]).fit(dfA)
print type(pipeA)
pipeA.save('testB.pipe')
pipeAA = PysparkPipelineWrapper.unwrap(PipelineModel.load('testB.pipe'))
stagesAA = pipeAA.stages
trAA = stagesAA[0]
print trAA.dataset_count
dfB = pipeAA.transform(dfA)
dfB.show()
Python 2 on Spark 2.2.0;ピクルスエラーが発生し続けました。いくつかの盲目の路地を通過した後、@ dmbakerの独創的なソリューションを機能させることができませんでした。パラメータ値を文字列として_StopWordsRemover's
_ストップワードに直接読み書きするように彼の(彼女の?)アイデアを変更することによる実用的な解決策。
独自の推定量または変換器を保存してロードする場合に必要な基本クラスは次のとおりです。
_from pyspark import SparkContext
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.util import Identifiable, MLWritable, JavaMLWriter, MLReadable, JavaMLReader
from pyspark.ml.wrapper import JavaWrapper, JavaParams
class PysparkReaderWriter(Identifiable, MLReadable, MLWritable):
"""
A base class for custom pyspark Estimators and Models to support saving and loading directly
or within a Pipeline or PipelineModel.
"""
def __init__(self):
super(PysparkReaderWriter, self).__init__()
@staticmethod
def _getPyObjIdPrefix():
return "_ThisIsReallyA_"
@classmethod
def _getPyObjId(cls):
return PysparkReaderWriter._getPyObjIdPrefix() + cls.__name__
def getParamsAsListOfStrings(self):
raise NotImplementedError("PysparkReaderWriter.getParamsAsListOfStrings() not implemented for instance: %r" % self)
def write(self):
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
def _to_Java(self):
# Convert all our parameters to strings:
paramValuesAsStrings = self.getParamsAsListOfStrings()
# Append our own type-specific id so PysparkPipelineLoader can detect this algorithm when unwrapping us.
paramValuesAsStrings.append(self._getPyObjId())
# Convert the parameter values to a Java array:
sc = SparkContext._active_spark_context
Java_array = JavaWrapper._new_Java_array(paramValuesAsStrings, sc._gateway.jvm.Java.lang.String)
# Create a Java (Scala) StopWordsRemover and give it the parameters as its stop words.
_Java_obj = JavaParams._new_Java_obj("org.Apache.spark.ml.feature.StopWordsRemover", self.uid)
_Java_obj.setStopWords(Java_array)
return _Java_obj
@classmethod
def _from_Java(cls, Java_obj):
# Get the stop words, ignoring the id at the end:
stopWords = Java_obj.getStopWords()[:-1]
return cls.createAndInitialisePyObj(stopWords)
@classmethod
def createAndInitialisePyObj(cls, paramsAsListOfStrings):
raise NotImplementedError("PysparkReaderWriter.createAndInitialisePyObj() not implemented for type: %r" % cls)
@classmethod
def read(cls):
"""Returns an MLReader instance for our clarrier class."""
return JavaMLReader(StopWordsRemover)
@classmethod
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
swr_Java_obj = cls.read().load(path)
return cls._from_Java(swr_Java_obj)
_
次に、独自のpysparkアルゴリズムがPysparkReaderWriter
から継承し、パラメータを文字列のリストに保存するgetParamsAsListOfStrings()
メソッドをオーバーライドする必要があります。アルゴリズムは、文字列のリストをパラメータに変換するためのcreateAndInitialisePyObj()
メソッドもオーバーライドする必要があります。舞台裏では、パラメータはStopWordsRemover
によって使用されるストップワードとの間で変換されます。
異なるタイプの3つのパラメーターを持つ推定量の例:
_from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.base import Estimator
class MyEstimator(Estimator, PysparkReaderWriter):
def __init__(self):
super(MyEstimator, self).__init__()
# 3 sample parameters, deliberately of different types:
stringParam = Param(Params._dummy(), "stringParam", "A dummy string parameter", typeConverter=TypeConverters.toString)
def setStringParam(self, value):
return self._set(stringParam=value)
def getStringParam(self):
return self.getOrDefault(self.stringParam)
listOfStringsParam = Param(Params._dummy(), "listOfStringsParam", "A dummy list of strings.", typeConverter=TypeConverters.toListString)
def setListOfStringsParam(self, value):
return self._set(listOfStringsParam=value)
def getListOfStringsParam(self):
return self.getOrDefault(self.listOfStringsParam)
intParam = Param(Params._dummy(), "intParam", "A dummy int parameter.", typeConverter=TypeConverters.toInt)
def setIntParam(self, value):
return self._set(intParam=value)
def getIntParam(self):
return self.getOrDefault(self.intParam)
def _fit(self, dataset):
model = MyModel()
# Just some changes to verify we can modify the model (and also it's something we can expect to see when restoring it later):
model.setAnotherStringParam(self.getStringParam() + " World!")
model.setAnotherListOfStringsParam(self.getListOfStringsParam() + ["E", "F"])
model.setAnotherIntParam(self.getIntParam() + 10)
return model
def getParamsAsListOfStrings(self):
paramValuesAsStrings = []
paramValuesAsStrings.append(self.getStringParam()) # Parameter is already a string
paramValuesAsStrings.append(','.join(self.getListOfStringsParam())) # ...convert from a list of strings
paramValuesAsStrings.append(str(self.getIntParam())) # ...convert from an int
return paramValuesAsStrings
@classmethod
def createAndInitialisePyObj(cls, paramsAsListOfStrings):
# Convert back into our parameters. Make sure you do this in the same order you saved them!
py_obj = cls()
py_obj.setStringParam(paramsAsListOfStrings[0])
py_obj.setListOfStringsParam(paramsAsListOfStrings[1].split(","))
py_obj.setIntParam(int(paramsAsListOfStrings[2]))
return py_obj
_
3つの異なるパラメーターを持つモデルの例(これもトランスフォーマー):
_from pyspark.ml.base import Model
class MyModel(Model, PysparkReaderWriter):
def __init__(self):
super(MyModel, self).__init__()
# 3 sample parameters, deliberately of different types:
anotherStringParam = Param(Params._dummy(), "anotherStringParam", "A dummy string parameter", typeConverter=TypeConverters.toString)
def setAnotherStringParam(self, value):
return self._set(anotherStringParam=value)
def getAnotherStringParam(self):
return self.getOrDefault(self.anotherStringParam)
anotherListOfStringsParam = Param(Params._dummy(), "anotherListOfStringsParam", "A dummy list of strings.", typeConverter=TypeConverters.toListString)
def setAnotherListOfStringsParam(self, value):
return self._set(anotherListOfStringsParam=value)
def getAnotherListOfStringsParam(self):
return self.getOrDefault(self.anotherListOfStringsParam)
anotherIntParam = Param(Params._dummy(), "anotherIntParam", "A dummy int parameter.", typeConverter=TypeConverters.toInt)
def setAnotherIntParam(self, value):
return self._set(anotherIntParam=value)
def getAnotherIntParam(self):
return self.getOrDefault(self.anotherIntParam)
def _transform(self, dataset):
# Dummy transform code:
return dataset.withColumn('age2', dataset.age + self.getAnotherIntParam())
def getParamsAsListOfStrings(self):
paramValuesAsStrings = []
paramValuesAsStrings.append(self.getAnotherStringParam()) # Parameter is already a string
paramValuesAsStrings.append(','.join(self.getAnotherListOfStringsParam())) # ...convert from a list of strings
paramValuesAsStrings.append(str(self.getAnotherIntParam())) # ...convert from an int
return paramValuesAsStrings
@classmethod
def createAndInitialisePyObj(cls, paramsAsListOfStrings):
# Convert back into our parameters. Make sure you do this in the same order you saved them!
py_obj = cls()
py_obj.setAnotherStringParam(paramsAsListOfStrings[0])
py_obj.setAnotherListOfStringsParam(paramsAsListOfStrings[1].split(","))
py_obj.setAnotherIntParam(int(paramsAsListOfStrings[2]))
return py_obj
_
以下は、モデルを保存およびロードする方法を示すサンプルテストケースです。推定量についても同様なので、簡潔にするために省略します。
_def createAModel():
m = MyModel()
m.setAnotherStringParam("Boo!")
m.setAnotherListOfStringsParam(["P", "Q", "R"])
m.setAnotherIntParam(77)
return m
def testSaveLoadModel():
modA = createAModel()
print(modA.explainParams())
savePath = "/whatever/path/you/want"
#modA.save(savePath) # Can't overwrite, so...
modA.write().overwrite().save(savePath)
modB = MyModel.load(savePath)
print(modB.explainParams())
testSaveLoadModel()
_
出力:
_anotherIntParam: A dummy int parameter. (current: 77)
anotherListOfStringsParam: A dummy list of strings. (current: ['P', 'Q', 'R'])
anotherStringParam: A dummy string parameter (current: Boo!)
anotherIntParam: A dummy int parameter. (current: 77)
anotherListOfStringsParam: A dummy list of strings. (current: [u'P', u'Q', u'R'])
anotherStringParam: A dummy string parameter (current: Boo!)
_
パラメータがUnicode文字列として戻ってきたことに注目してください。これは、_transform()
(または推定量の場合は_fit()
)で実装する基礎となるアルゴリズムに違いをもたらす場合とそうでない場合があります。したがって、これに注意してください。
最後に、舞台裏のScalaアルゴリズムは実際にはStopWordsRemover
であるため、ディスクからPipeline
またはPipelineModel
をロードするときに、それを独自のクラスにアンラップする必要があります。これがこのアンラッピングを行うユーティリティクラス:
_from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import StopWordsRemover
class PysparkPipelineLoader(object):
"""
A class to facilitate converting the stages of a Pipeline or PipelineModel
that were saved from PysparkReaderWriter.
"""
def __init__(self):
super(PysparkPipelineLoader, self).__init__()
@staticmethod
def unwrap(thingToUnwrap, customClassList):
if not (isinstance(thingToUnwrap, Pipeline) or isinstance(thingToUnwrap, PipelineModel)):
raise TypeError("Cannot recognize an object of type %s." % type(thingToUnwrap))
stages = thingToUnwrap.getStages() if isinstance(thingToUnwrap, Pipeline) else thingToUnwrap.stages
for i, stage in enumerate(stages):
if (isinstance(stage, Pipeline) or isinstance(stage, PipelineModel)):
stages[i] = PysparkPipelineLoader.unwrap(stage)
if isinstance(stage, StopWordsRemover) and stage.getStopWords()[-1].startswith(PysparkReaderWriter._getPyObjIdPrefix()):
lastWord = stage.getStopWords()[-1]
className = lastWord[len(PysparkReaderWriter._getPyObjIdPrefix()):]
stopWords = stage.getStopWords()[:-1] # Strip the id
# Create and initialise the appropriate class:
py_obj = None
for clazz in customClassList:
if clazz.__name__ == className:
py_obj = clazz.createAndInitialisePyObj(stopWords)
if py_obj is None:
raise TypeError("I don't know how to create an instance of type: %s" % className)
stages[i] = py_obj
if isinstance(thingToUnwrap, Pipeline):
thingToUnwrap.setStages(stages)
else:
# PipelineModel
thingToUnwrap.stages = stages
return thingToUnwrap
_
パイプラインの保存と読み込みをテストします。
_def testSaveAndLoadUnfittedPipeline():
estA = createAnEstimator()
#print(estA.explainParams())
pipelineA = Pipeline(stages=[estA])
savePath = "/whatever/path/you/want"
#pipelineA.save(savePath) # Can't overwrite, so...
pipelineA.write().overwrite().save(savePath)
pipelineReloaded = PysparkPipelineLoader.unwrap(Pipeline.load(savePath), [MyEstimator])
estB = pipelineReloaded.getStages()[0]
print(estB.explainParams())
testSaveAndLoadUnfittedPipeline()
_
出力:
_intParam: A dummy int parameter. (current: 42)
listOfStringsParam: A dummy list of strings. (current: [u'A', u'B', u'C', u'D'])
stringParam: A dummy string parameter (current: Hello)
_
パイプラインモデルの保存と読み込みをテストします。
_from pyspark.sql import Row
def make_a_dataframe(sc):
df = sc.parallelize([Row(name='Alice', age=5, height=80), Row(name='Bob', age=7, height=85), Row(name='Chris', age=10, height=90)]).toDF()
return df
def testSaveAndLoadPipelineModel():
dfA = make_a_dataframe(sc)
estA = createAnEstimator()
#print(estA.explainParams())
pipelineModelA = Pipeline(stages=[estA]).fit(dfA)
savePath = "/whatever/path/you/want"
#pipelineModelA.save(savePath) # Can't overwrite, so...
pipelineModelA.write().overwrite().save(savePath)
pipelineModelReloaded = PysparkPipelineLoader.unwrap(PipelineModel.load(savePath), [MyModel])
modB = pipelineModelReloaded.stages[0]
print(modB.explainParams())
dfB = pipelineModelReloaded.transform(dfA)
dfB.show()
testSaveAndLoadPipelineModel()
_
出力:
_anotherIntParam: A dummy int parameter. (current: 52)
anotherListOfStringsParam: A dummy list of strings. (current: [u'A', u'B', u'C', u'D', u'E', u'F'])
anotherStringParam: A dummy string parameter (current: Hello World!)
+---+------+-----+----+
|age|height| name|age2|
+---+------+-----+----+
| 5| 80|Alice| 57|
| 7| 85| Bob| 59|
| 10| 90|Chris| 62|
+---+------+-----+----+
_
パイプラインまたはパイプラインモデルをアンラップするときは、保存されたパイプラインまたはパイプラインモデルでStopWordsRemover
オブジェクトになりすましている独自のpysparkアルゴリズムに対応するクラスのリストを渡す必要があります。保存したオブジェクトの最後のストップワードは、自分のクラスの名前を識別するために使用され、次にcreateAndInitialisePyObj()
が呼び出されてクラスのインスタンスが作成され、残りのストップワードでパラメーターが初期化されます。
さまざまな改良を加えることができます。しかし、うまくいけば、これにより、 SPARK-17025 が解決されて利用可能になるまで、パイプラインの内側と外側の両方でカスタム推定器とトランスフォーマーを保存およびロードできるようになります。
@dmbakerによる working answer と同様に、Aggregator
というカスタムトランスフォーマーを組み込みのSparkトランスフォーマー、この例では、 Binarizer
、ただし、他のトランスフォーマーからも継承できると確信しています。これにより、カスタムトランスフォーマーはシリアル化に必要なメソッドを継承できました。
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, Binarizer
from pyspark.ml.regression import LinearRegression
class Aggregator(Binarizer):
"""A huge hack to allow serialization of custom transformer."""
def transform(self, input_df):
agg_df = input_df\
.groupBy('channel_id')\
.agg({
'foo': 'avg',
'bar': 'avg',
})\
.withColumnRenamed('avg(foo)', 'avg_foo')\
.withColumnRenamed('avg(bar)', 'avg_bar')
return agg_df
# Create pipeline stages.
aggregator = Aggregator()
vector_assembler = VectorAssembler(...)
linear_regression = LinearRegression()
# Create pipeline.
pipeline = Pipeline(stages=[aggregator, vector_assembler, linear_regression])
# Train.
pipeline_model = pipeline.fit(input_df)
# Save model file to S3.
pipeline_model.save('s3n://example')
@dmbakerソリューションは私には機能しませんでした。これは、pythonバージョン(2.x対3.x)であるためだと思います。彼のソリューションにいくつかの更新を加えたところ、Python 3 。私の設定は以下のとおりです。
class PysparkObjId(object):
"" "
python
推定器、トランスフォーマー、モデルをシリアル化できるように識別および設定するために使用される定数を指定するクラスそこに
所有し、PiplineまたはPipelineModel内から。
"" "
def init(self):
super(PysparkObjId 、self).init()
@staticmethod
def _getPyObjId():
return '4c1740b00d3c4ff6806a1402321572cb'
@staticmethod
def _getCarrierClass(javaName=False):
return 'org.Apache.spark.ml.feature.StopWordsRemover' if javaName else StopWordsRemover
class PysparkPipelineWrapper(object):
"""
A class to facilitate converting the stages of a Pipeline or PipelineModel
that were saved from PysparkReaderWriter.
"""
def __init__(self):
super(PysparkPipelineWrapper, self).__init__()
@staticmethod
def unwrap(pipeline):
if not (isinstance(pipeline, Pipeline) or isinstance(pipeline, PipelineModel)):
raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))
stages = pipeline.getStages() if isinstance(pipeline, Pipeline) else pipeline.stages
for i, stage in enumerate(stages):
if (isinstance(stage, Pipeline) or isinstance(stage, PipelineModel)):
stages[i] = PysparkPipelineWrapper.unwrap(stage)
if isinstance(stage, PysparkObjId._getCarrierClass()) and stage.getStopWords()[-1] == PysparkObjId._getPyObjId():
swords = stage.getStopWords()[:-1] # strip the id
# convert stop words to int
swords = [int(d) for d in swords]
# get the byte value of all ints
lst = [x.to_bytes(length=1, byteorder='big') for x in
swords] # convert from string integer list to bytes
# return the first byte and concatenates all the others
dmp = lst[0]
for byte_counter in range(1, len(lst)):
dmp = dmp + lst[byte_counter]
py_obj = dill.loads(dmp)
stages[i] = py_obj
if isinstance(pipeline, Pipeline):
pipeline.setStages(stages)
else:
pipeline.stages = stages
return pipeline
class PysparkReaderWriter(object):
"""
A mixin class so custom pyspark Estimators, Transformers and Models may
support saving and loading directly or be saved within a Pipline or PipelineModel.
"""
def __init__(self):
super(PysparkReaderWriter, self).__init__()
def write(self):
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
@classmethod
def read(cls):
"""Returns an MLReader instance for our clarrier class."""
return JavaMLReader(PysparkObjId._getCarrierClass())
@classmethod
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
swr_Java_obj = cls.read().load(path)
return cls._from_Java(swr_Java_obj)
@classmethod
def _from_Java(cls, Java_obj):
"""
Get the dumby the stopwords that are the characters of the dills dump plus our guid
and convert, via dill, back to our python instance.
"""
swords = Java_obj.getStopWords()[:-1] # strip the id
lst = [x.to_bytes(length=1, byteorder='big') for x in swords] # convert from string integer list to bytes
dmp = lst[0]
for i in range(1, len(lst)):
dmp = dmp + lst[i]
py_obj = dill.loads(dmp)
return py_obj
def _to_Java(self):
"""
Convert this instance to a dill dump, then to a list of strings with the unicode integer values of each character.
Use this list as a set of dumby stopwords and store in a StopWordsRemover instance
:return: Java object equivalent to this instance.
"""
dmp = dill.dumps(self)
pylist = [str(int(d)) for d in dmp] # convert bytes to string integer list
pylist.append(PysparkObjId._getPyObjId()) # add our id so PysparkPipelineWrapper can id us.
sc = SparkContext._active_spark_context
Java_class = sc._gateway.jvm.Java.lang.String
Java_array = sc._gateway.new_array(Java_class, len(pylist))
for i in range(len(pylist)):
Java_array[i] = pylist[i]
_Java_obj = JavaParams._new_Java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
_Java_obj.setStopWords(Java_array)
return _Java_obj
class HasFake(Params):
def __init__(self):
super(HasFake, self).__init__()
self.fake = Param(self, "fake", "fake param")
def getFake(self):
return self.getOrDefault(self.fake)
class CleanText(Transformer, HasInputCol, HasOutputCol, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
@keyword_only
def __init__(self, inputCol=None, outputCol=None):
super(CleanText, self).__init__()
kwargs = self._input_kwargs
self.setParams(**kwargs)