How can I use a Spark pipeline to search a string for phrases and then add feature categories?
I want to search the text column of a PySpark DataFrame for phrases. Here is an example to show what I mean.
sentenceData = spark.createDataFrame([
    (0, "Hi I heard about Spark"),
    (4, "I wish Java could use case classes"),
    (11, "Logistic regression models are neat")
], ["id", "sentence"])
If the sentence contains "heard about Spark", then categorySpark=1 and categoryHeard=1.
If the sentence contains "java OR Regression", then categoryCool=1.
I have about 28 booleans to check (or possibly fewer if I use regexes, which might work better).
sentenceData.withColumn('categoryCool',sentenceData['sentence'].rlike('Java | regression')).show()
returns:
+---+--------------------+------------+
| id| sentence|categoryCool|
+---+--------------------+------------+
| 0|Hi I heard about ...| false|
| 4|I wish Java could...| true|
| 11|Logistic regressi...| true|
+---+--------------------+------------+
This is exactly what I want, but I want to add it to a pipeline as a transform step.
Solution
I found this nice Medium article and this S.O. answer, and combined them to answer my own question! I hope someone finds this helpful someday.
from pyspark.ml.pipeline import Transformer
from pyspark.ml import Pipeline
from pyspark.sql.types import *
from pyspark.ml.util import Identifiable
sentenceData = spark.createDataFrame([
    (0, "Hi I heard about Spark"),
    (4, "I wish Java could use case classes"),
    (11, "Logistic regression models are neat")
], ["id", "sentence"])
class OneSearchMultiLabelExtractor(Transformer):
    def __init__(self, rlikeSearch, outputCols, inputCol='fullText'):
        # Initialize the Params machinery before setting our own attributes
        super(OneSearchMultiLabelExtractor, self).__init__()
        self.inputCol = inputCol
        self.outputCols = outputCols
        self.rlikeSearch = rlikeSearch
        self.uid = str(Identifiable())

    def copy(self, extra=None):
        return super(OneSearchMultiLabelExtractor, self).copy(extra)

    def check_input_type(self, schema):
        field = schema[self.inputCol]
        if field.dataType != StringType():
            raise Exception('OneSearchMultiLabelExtractor input type %s did not match input type StringType' % field.dataType)

    def check_output_type(self):
        if not isinstance(self.outputCols, list):
            raise Exception('OneSearchMultiLabelExtractor output columns must be a list')

    def _transform(self, df):
        self.check_input_type(df.schema)
        self.check_output_type()
        # Run the search once, then copy the boolean result into every output column
        df = df.withColumn("searchResult", df[self.inputCol].rlike(self.rlikeSearch)).cache()
        for outputCol in self.outputCols:
            df = df.withColumn(outputCol, df["searchResult"])
        return df.drop("searchResult")
dex = OneSearchMultiLabelExtractor(inputCol='sentence', rlikeSearch='Java | regression', outputCols=['coolCategory'])
FeaturesPipeline = Pipeline(stages=[dex])
Featpip = FeaturesPipeline.fit(sentenceData)
Featpip.transform(sentenceData).show()